├── .classpath
├── .gitignore
├── .project
├── .settings
├── org.eclipse.core.resources.prefs
└── org.scala-ide.sdt.core.prefs
├── LICENSE
├── MANIFEST.MF.prototype
├── README.md
├── build.properties
├── build.sbt
├── patch-ide.bash
├── project
├── build.properties
└── build.sbt
└── src
├── main
├── java
│ └── scala
│ │ └── tools
│ │ └── refactoring
│ │ └── JavadocStub.java
├── scala-2.10
│ ├── README
│ └── scala
│ │ └── tools
│ │ └── refactoring
│ │ ├── ScalaVersionsAdapter.scala
│ │ └── implementations
│ │ └── oimports
│ │ └── ImplicitValDefTraverserPF.scala
├── scala-2.11
│ ├── README
│ └── scala
│ │ └── tools
│ │ └── refactoring
│ │ ├── ScalaVersionsAdapter.scala
│ │ └── implementations
│ │ └── oimports
│ │ └── ImplicitValDefTraverserPF.scala
├── scala-2.12
│ ├── README
│ └── scala
│ │ └── tools
│ │ └── refactoring
│ │ ├── ScalaVersionsAdapter.scala
│ │ └── implementations
│ │ └── oimports
│ │ └── ImplicitValDefTraverserPF.scala
└── scala
│ └── scala
│ └── tools
│ └── refactoring
│ ├── MultiStageRefactoring.scala
│ ├── ParameterlessRefactoring.scala
│ ├── Refactoring.scala
│ ├── analysis
│ ├── CompilationUnitDependencies.scala
│ ├── CompilationUnitIndexes.scala
│ ├── GlobalIndexes.scala
│ ├── ImportAnalysis.scala
│ ├── ImportsToolbox.scala
│ ├── Indexes.scala
│ ├── NameValidation.scala
│ ├── PartiallyAppliedMethodsFinder.scala
│ ├── ScopeAnalysis.scala
│ ├── SymbolExpanders.scala
│ └── TreeAnalysis.scala
│ ├── common
│ ├── Change.scala
│ ├── CompilerAccess.scala
│ ├── CompilerApiExtensions.scala
│ ├── EnrichedTrees.scala
│ ├── InsertionPositions.scala
│ ├── InteractiveScalaCompiler.scala
│ ├── Occurrences.scala
│ ├── PositionDebugging.scala
│ ├── Selections.scala
│ ├── TracingHelpers.scala
│ ├── TreeExtractors.scala
│ ├── TreeTraverser.scala
│ ├── exceptions.scala
│ ├── package.scala
│ └── tracing.scala
│ ├── implementations
│ ├── AddField.scala
│ ├── AddImportStatement.scala
│ ├── AddMethod.scala
│ ├── AddValOrDef.scala
│ ├── ChangeParamOrder.scala
│ ├── ClassParameterDrivenSourceGeneration.scala
│ ├── ExpandCaseClassBinding.scala
│ ├── ExplicitGettersSetters.scala
│ ├── ExtractLocal.scala
│ ├── ExtractMethod.scala
│ ├── ExtractTrait.scala
│ ├── GenerateHashcodeAndEquals.scala
│ ├── ImportsHelper.scala
│ ├── InlineLocal.scala
│ ├── IntroduceProductNTrait.scala
│ ├── MarkOccurrences.scala
│ ├── MergeParameterLists.scala
│ ├── MethodSignatureRefactoring.scala
│ ├── MoveClass.scala
│ ├── MoveConstructorToCompanionObject.scala
│ ├── OrganizeImports.scala
│ ├── Rename.scala
│ ├── SplitParameterLists.scala
│ ├── UnusedImportsFinder.scala
│ ├── extraction
│ │ ├── ExtractCode.scala
│ │ ├── ExtractExtractor.scala
│ │ ├── ExtractMethod.scala
│ │ ├── ExtractParameter.scala
│ │ ├── ExtractValue.scala
│ │ └── ExtractionRefactoring.scala
│ └── oimports
│ │ ├── ImportParticipants.scala
│ │ ├── ImportsOrganizer.scala
│ │ ├── OrganizeImportsWorker.scala
│ │ ├── Region.scala
│ │ ├── RegionTransformations.scala
│ │ └── TreeToolbox.scala
│ ├── package.scala
│ ├── sourcegen
│ ├── AbstractPrinter.scala
│ ├── CommentsUtils.scala
│ ├── CommonPrintUtils.scala
│ ├── Formatting.scala
│ ├── Fragment.scala
│ ├── Indentation.scala
│ ├── Layout.scala
│ ├── LayoutHelper.scala
│ ├── PrettyPrinter.scala
│ ├── Requisite.scala
│ ├── ReusingPrinter.scala
│ ├── SourceGenerator.scala
│ ├── SourceUtils.scala
│ ├── TreeChangesDiscoverer.scala
│ └── TreePrintingTraversals.scala
│ ├── transformation
│ ├── TransformableSelections.scala
│ ├── Transformations.scala
│ ├── TreeFactory.scala
│ └── TreeTransformations.scala
│ └── util
│ ├── CompilerProvider.scala
│ ├── Memoized.scala
│ ├── SourceHelpers.scala
│ ├── SourceWithMarker.scala
│ ├── SourceWithSelection.scala
│ ├── UnionFind.scala
│ └── UniqueNames.scala
└── test
├── java
└── scala
│ └── tools
│ └── refactoring
│ ├── common
│ └── TracingHelpersTest.scala
│ └── tests
│ └── util
│ ├── ExceptionWrapper.java
│ ├── ScalaVersion.java
│ ├── ScalaVersionTestRule.java
│ └── TestRules.java
├── scala-2.10
├── README
└── scala
│ └── tools
│ └── refactoring
│ └── tests
│ └── implementations
│ └── imports
│ ├── OrganizeImportsScalaSpecificTests.scala
│ └── OrganizeImportsWithMacrosTest.scala
├── scala-2.11
├── README
└── scala
│ └── tools
│ └── refactoring
│ └── tests
│ └── implementations
│ └── imports
│ ├── OrganizeImportsScalaSpecificTests.scala
│ └── OrganizeImportsWithMacrosTest.scala
├── scala-2.12
├── README
└── scala
│ └── tools
│ └── refactoring
│ └── tests
│ └── implementations
│ └── imports
│ ├── OrganizeImportsScalaSpecificTests.scala
│ └── OrganizeImportsWithMacrosTest.scala
└── scala
└── scala
└── tools
└── refactoring
├── implementations
└── OrganizeImportsAlgosTest.scala
└── tests
├── RefactoringTestSuite.scala
├── analysis
├── CompilationUnitDependenciesTest.scala
├── DeclarationIndexTest.scala
├── FindShadowedTest.scala
├── ImportAnalysisTest.scala
├── MultipleFilesIndexTest.scala
├── NameValidationTest.scala
├── ScopeAnalysisTest.scala
└── TreeAnalysisTest.scala
├── common
├── EnrichedTreesTest.scala
├── InsertionPositionsTest.scala
├── OccurrencesTest.scala
├── SelectionDependenciesTest.scala
├── SelectionExpansionsTest.scala
├── SelectionPropertiesTest.scala
└── SelectionsTest.scala
├── implementations
├── AddFieldTest.scala
├── AddMethodTest.scala
├── ChangeParamOrderTest.scala
├── ExpandCaseClassBindingTest.scala
├── ExplicitGettersSettersTest.scala
├── ExtractLocalTest.scala
├── ExtractMethodTest.scala
├── ExtractTraitTest.scala
├── GenerateHashcodeAndEqualsTest.scala
├── InlineLocalTest.scala
├── IntroduceProductNTraitTest.scala
├── MarkOccurrencesTest.scala
├── MergeParameterListsTest.scala
├── MoveClassTest.scala
├── MoveConstructorToCompanionObjectTest.scala
├── RenameTest.scala
├── SplitParameterListsTest.scala
├── extraction
│ ├── ExtractCodeTest.scala
│ ├── ExtractExtractorTest.scala
│ ├── ExtractMethodTest.scala
│ ├── ExtractParameterTest.scala
│ ├── ExtractValueTest.scala
│ └── ExtractionsTest.scala
└── imports
│ ├── AddImportStatementTest.scala
│ ├── OrganizeImportsBaseTest.scala
│ ├── OrganizeImportsCollapseSelectorsToWildcardTest.scala
│ ├── OrganizeImportsEndOfLineTest.scala
│ ├── OrganizeImportsFullyRecomputeTest.scala
│ ├── OrganizeImportsGroupsTest.scala
│ ├── OrganizeImportsRecomputeAndModifyTest.scala
│ ├── OrganizeImportsTest.scala
│ ├── OrganizeImportsWildcardsTest.scala
│ ├── OrganizeMissingImportsTest.scala
│ ├── PrependOrDropScalaPackageFromRecomputedTest.scala
│ ├── PrependOrDropScalaPackageKeepTest.scala
│ └── UnusedImportsFinderTest.scala
├── sourcegen
├── CustomFormattingTest.scala
├── IndividualSourceGenTest.scala
├── LayoutTest.scala
├── PrettyPrinterTest.scala
├── ReusingPrinterTest.scala
├── SourceGenTest.scala
├── SourceHelperTest.scala
├── SourceUtilsTest.scala
└── TreeChangesDiscovererTest.scala
├── transformation
├── TransformableSelectionTest.scala
└── TreeTransformationsTest.scala
└── util
├── FreshCompilerForeachTest.scala
├── SourceHelpersTest.scala
├── SourceWithMarkerTest.scala
├── TestHelper.scala
├── TestRefactoring.scala
├── TextSelections.scala
├── TextSelectionsTest.scala
├── UnionFindInitTest.scala
└── UnionFindTest.scala
/.classpath:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 | bin
3 | .cache-main
4 | .cache-tests
5 | .attach_pid*
6 | **/*.tmpBin
7 | *.swp
8 | .ensime*
9 | .idea
10 | *.iml
11 |
--------------------------------------------------------------------------------
/.project:
--------------------------------------------------------------------------------
1 |
2 |
3 | org.scala-refactoring.library
4 |
5 |
6 |
7 |
8 |
9 | org.eclipse.pde.ManifestBuilder
10 |
11 |
12 |
13 |
14 | org.eclipse.pde.SchemaBuilder
15 |
16 |
17 |
18 |
19 | org.scala-ide.sdt.core.scalabuilder
20 |
21 |
22 |
23 |
24 | org.scalastyle.scalastyleplugin.core.ScalastyleBuilder
25 |
26 |
27 |
28 |
29 |
30 | org.scala-ide.sdt.core.scalanature
31 | org.eclipse.jdt.core.javanature
32 | org.eclipse.pde.PluginNature
33 | org.scalastyle.scalastyleplugin.core.ScalastyleNature
34 |
35 |
36 |
--------------------------------------------------------------------------------
/.settings/org.eclipse.core.resources.prefs:
--------------------------------------------------------------------------------
1 | #Generated by sbteclipse
2 | #Thu Oct 13 18:48:15 CEST 2016
3 | encoding/=UTF-8
4 |
--------------------------------------------------------------------------------
/.settings/org.scala-ide.sdt.core.prefs:
--------------------------------------------------------------------------------
1 | //src/main/scala=main
2 | //src/main/scala-2.12=main
3 | //src/test/java=tests
4 | //src/test/scala=tests
5 | //src/test/scala-2.12=tests
6 | P=
7 | Xcheckinit=false
8 | Xdisable-assertions=false
9 | Xelide-below=-2147483648
10 | Xexperimental=false
11 | Xfatal-warnings=true
12 | Xfuture=true
13 | Xlog-implicits=false
14 | Xno-uescape=false
15 | Xplugin=
16 | Xplugin-disable=
17 | Xplugin-require=
18 | Xpluginsdir=misc/scala-devel/plugins
19 | Ypresentation-debug=false
20 | Ypresentation-delay=0
21 | Ypresentation-log=
22 | Ypresentation-replay=
23 | Ypresentation-verbose=false
24 | Ywarn-dead-code=true
25 | apiDiff=false
26 | compileorder=Mixed
27 | deprecation=true
28 | eclipse.preferences.version=1
29 | explaintypes=false
30 | feature=true
31 | g=vars
32 | no-specialization=false
33 | nowarn=false
34 | optimise=false
35 | recompileOnMacroDef=true
36 | relationsDebug=false
37 | scala.compiler.additionalParams=-deprecation\:false -encoding UTF-8 -feature -language\:_ -Xlint\:-unused,_ -Yno-adapted-args -Ywarn-unused-import -Ywarn-unused\:imports,privates,locals,-patvars,-params,-implicits,_
38 | scala.compiler.installation=2.12
39 | scala.compiler.sourceLevel=2.12
40 | scala.compiler.useProjectSettings=true
41 | stopBuildOnError=true
42 | target=jvm-1.8
43 | unchecked=true
44 | useScopesCompiler=true
45 | verbose=false
46 | withVersionClasspathValidator=false
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | SCALA LICENSE
2 |
3 | Copyright (c) 2002-2010 EPFL, Lausanne, unless otherwise specified.
4 | All rights reserved.
5 |
6 | This software was developed by the Programming Methods Laboratory of the
7 | Swiss Federal Institute of Technology (EPFL), Lausanne, Switzerland.
8 |
9 | Permission to use, copy, modify, and distribute this software in source
10 | or binary form for any purpose with or without fee is hereby granted,
11 | provided that the following conditions are met:
12 |
13 | 1. Redistributions of source code must retain the above copyright
14 | notice, this list of conditions and the following disclaimer.
15 |
16 | 2. Redistributions in binary form must reproduce the above copyright
17 | notice, this list of conditions and the following disclaimer in the
18 | documentation and/or other materials provided with the distribution.
19 |
20 | 3. Neither the name of the EPFL nor the names of its contributors
21 | may be used to endorse or promote products derived from this
22 | software without specific prior written permission.
23 |
24 |
25 | THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND
26 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28 | ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
29 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
33 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
34 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
35 | SUCH DAMAGE.
36 |
--------------------------------------------------------------------------------
/MANIFEST.MF.prototype:
--------------------------------------------------------------------------------
1 | Manifest-Version: 1.0
2 | Bundle-ManifestVersion: 2
3 | Bundle-Name: Scala Refactoring
4 | Bundle-SymbolicName: org.scala-refactoring.library
5 | Bundle-Version: version.qualifier
6 | Require-Bundle: org.junit;bundle-version="4.11.0",
7 | org.scala-lang.scala-library
8 | Export-Package: scala.tools.refactoring,scala.tools.refactoring.analys
9 | is,scala.tools.refactoring.common,scala.tools.refactoring.implementat
10 | ions,scala.tools.refactoring.implementations.extraction,scala.tools.r
11 | efactoring.sourcegen,scala.tools.refactoring.transformation,scala.too
12 | ls.refactoring.util
13 | Bundle-ClassPath: .
14 | Import-Package: scala.reflect.internal;resolution:=optional,
15 | scala.reflect.internal.util;resolution:=optional,
16 | scala.reflect.api;resolution:=optional,
17 | scala.reflect.runtime;resolution:=optional,
18 | scala.reflect.macros;resolution:=optional,
19 | scala.reflect.io;resolution:=optional,
20 | scala.tools.nsc,
21 | scala.tools.nsc.ast,
22 | scala.tools.nsc.ast.parser,
23 | scala.tools.nsc.interactive,
24 | scala.tools.nsc.io,
25 | scala.tools.nsc.reporters,
26 | scala.tools.nsc.settings,
27 | scala.tools.nsc.symtab,
28 | scala.tools.nsc.typechecker,
29 | scala.tools.nsc.util
30 | Bundle-RequiredExecutionEnvironment: JavaSE-1.6
31 |
--------------------------------------------------------------------------------
/build.properties:
--------------------------------------------------------------------------------
1 | bin.includes = META-INF/,\
2 | .,\
3 | src/main/,\
4 | src/test/
5 | src.includes = src/
6 | jars.compile.order = .
7 | source.. = src/main/scala/,\
8 | src/test/scala/,\
9 | src/test/java/
10 | output.. = bin/
11 |
12 |
--------------------------------------------------------------------------------
/patch-ide.bash:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Copies the most recent local build of scala-refactoring over the
# org.scala-refactoring.library jar of a local ScalaIDE installation.
# Run without SCALA_IDE_HOME set to print usage information.

SCRIPT_NAME="./$(basename "$0")"

showHelp() {
  echo "Patches ScalaIDE with the latest local build"
  echo "--------------------------------------------"
  echo ""
  echo "The script is controlled by the following environment variables:"
  echo "  SCALA_IDE_HOME (mandatory):"
  echo "    Path to your local ScalaIDE installation"
  echo "  KEEP_REFACTORING_LIBRARY_BACKUP (optional, default: true):"
  echo "    Set to false to skip the timestamped backup of the replaced jar"
  echo "  REFACTORING_TARGET_FOLDER (optional, default: ./target/scala-2.11/):"
  echo "    Directory that contains the locally built *SNAPSHOT.jar"
  echo ""
  echo "Examples: "
  echo "  SCALA_IDE_HOME=\"/path/to/scala-ide\" $SCRIPT_NAME"
  echo "  SCALA_IDE_HOME=\"/path/to/scala-ide\" KEEP_REFACTORING_LIBRARY_BACKUP=true $SCRIPT_NAME"
  echo "  SCALA_IDE_HOME=\"/path/to/scala-ide\" REFACTORING_TARGET_FOLDER=./target/scala-2.12/ $SCRIPT_NAME"
  echo ""
  echo "Best practice:"
  echo "  If you use the script regularly, it is recommended to export"
  echo "  an appropriate value for SCALA_IDE_HOME via your bashrc, so that you"
  echo "  don't have to specify this setting repeatedly."
  echo ""
  echo "Warning:"
  echo "  Note that patching the IDE like this only works as long as"
  echo "  binary compatibility is maintained. Watch out for"
  echo "  AbstractMethodErrors and the like."
}

showHelpAndDie() {
  showHelp
  exit 1
}

# Prints all arguments to stderr.
echoErr() {
  cat <<< "$@" 1>&2
}

KEEP_REFACTORING_LIBRARY_BACKUP=${KEEP_REFACTORING_LIBRARY_BACKUP:-true}

# Overridable so that builds for other Scala versions can be installed; the
# trailing slash is normalized because the jar glob below appends directly.
TARGET_FOLDER="${REFACTORING_TARGET_FOLDER:-./target/scala-2.11/}"
TARGET_FOLDER="${TARGET_FOLDER%/}/"

if [[ -z "$SCALA_IDE_HOME" ]]; then
  showHelpAndDie
fi

SCALA_IDE_PLUGINS_DIR="$SCALA_IDE_HOME/plugins"

if [[ ! -d "$SCALA_IDE_PLUGINS_DIR" || ! -w "$SCALA_IDE_PLUGINS_DIR" ]]; then
  echoErr "Invalid SCALA_IDE_HOME: $SCALA_IDE_PLUGINS_DIR is not a writable directory"
  exit 1
fi

# With nullglob, a pattern without matches expands to an empty array.
shopt -s nullglob
_newRefactoringJars=("$TARGET_FOLDER"*SNAPSHOT.jar)
NEW_REFACTORING_JAR="${_newRefactoringJars[0]}"

if [[ ! -f "$NEW_REFACTORING_JAR" || ! -r "$NEW_REFACTORING_JAR" ]]; then
  echoErr "Cannot find a build of the library in $TARGET_FOLDER"
  exit 1
fi

TSTAMP="$(date +%Y-%m-%dT%H-%M-%S)"

# The directory part is quoted so that installation paths containing spaces work.
_oldRefactoringJar=("$SCALA_IDE_PLUGINS_DIR"/org.scala-refactoring.library*.jar)
if [[ ${#_oldRefactoringJar[@]} -eq 0 ]]; then
  echoErr "Cannot find the refactoring library in $SCALA_IDE_PLUGINS_DIR"
  exit 1
elif [[ ${#_oldRefactoringJar[@]} -gt 1 ]]; then
  echoErr "Multiple copies of the refactoring library found in $SCALA_IDE_PLUGINS_DIR:"
  for jarFile in "${_oldRefactoringJar[@]}"; do
    echoErr "  $jarFile"
  done
  exit 1
fi

OLD_REFACTORING_JAR="${_oldRefactoringJar[0]}"

if [[ "$KEEP_REFACTORING_LIBRARY_BACKUP" == "true" ]]; then
  BACKUP_REFACTORING_JAR="$OLD_REFACTORING_JAR.$TSTAMP.bak"
  # Never overwrite the installed jar if its backup could not be written.
  if ! cp "$OLD_REFACTORING_JAR" "$BACKUP_REFACTORING_JAR"; then
    echoErr "Could not back up $OLD_REFACTORING_JAR; leaving it untouched"
    exit 1
  fi
fi

cp -i "$NEW_REFACTORING_JAR" "$OLD_REFACTORING_JAR"
80 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=0.13.16
2 |
--------------------------------------------------------------------------------
/project/build.sbt:
--------------------------------------------------------------------------------
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")

// scala-refactoring is also built as part of the Scala PR CI, where no plugin
// that depends on scalac may be used. That build sets OMIT_SCOVERAGE_PLUGIN to
// leave the scoverage plugin out; regular scala-refactoring builds must not set it.
if (!sys.env.contains("OMIT_SCOVERAGE_PLUGIN"))
  List(addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0"))
else
  Nil
11 |
--------------------------------------------------------------------------------
/src/main/java/scala/tools/refactoring/JavadocStub.java:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring;
2 |
/**
 * Deliberately empty class. Sonatype requires a javadoc jar for published
 * artifacts; this stub exists only so that such a (fake) javadoc.jar can be
 * created.
 */
public class JavadocStub {
}
9 |
--------------------------------------------------------------------------------
/src/main/scala-2.10/README:
--------------------------------------------------------------------------------
1 | Put sources specifically for Scala-2.10.x here
--------------------------------------------------------------------------------
/src/main/scala-2.10/scala/tools/refactoring/ScalaVersionsAdapter.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 |
/**
 * Scala 2.10 specific shims over compiler APIs that differ between the
 * supported Scala versions.
 */
object ScalaVersionAdapters {
  trait CompilerApiAdapters {
    val global: scala.tools.nsc.Global
    import global._

    /** The tree of an annotation; on Scala 2.10 this is `AnnotationInfo.original`. */
    def annotationInfoTree(info: AnnotationInfo): Tree =
      info.original

    /**
     * Tells whether `sym` is an implementation artifact.
     *
     * On Scala 2.10 `Symbol.isImplementationArtifact` may return wrong results,
     * so as a workaround every symbol whose name contains a `$` is treated as an
     * artifact as well. The name is only inspected if the compiler did not
     * already flag the symbol.
     */
    def isImplementationArtifact(sym: Symbol): Boolean =
      if (sym.isImplementationArtifact) true
      else sym.name.toString.contains("$")
  }
}
20 |
--------------------------------------------------------------------------------
/src/main/scala-2.10/scala/tools/refactoring/implementations/oimports/ImplicitValDefTraverserPF.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations.oimports
3 |
4 | import scala.tools.nsc.Global
5 |
/**
 * Scala 2.10 variant of the traverser extension for `ValDef`s whose right-hand
 * side stems from a macro expansion.
 *
 * This is not supported on Scala 2.10, so the partial function produced here
 * is defined for no tree at all.
 */
class ImplicitValDefTraverserPF[G <: Global](val global: G) {
  import global._

  /** Unsupported for Scala 2.10: always returns the empty partial function. */
  def apply(traverser: Traverser): PartialFunction[Tree, Unit] =
    PartialFunction.empty[Tree, Unit]
}
12 |
--------------------------------------------------------------------------------
/src/main/scala-2.11/README:
--------------------------------------------------------------------------------
1 | Put sources specifically for Scala-2.11.x here
--------------------------------------------------------------------------------
/src/main/scala-2.11/scala/tools/refactoring/ScalaVersionsAdapter.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 |
/**
 * Scala 2.11 specific shims over compiler APIs that differ between the
 * supported Scala versions.
 */
object ScalaVersionAdapters {
  trait CompilerApiAdapters {
    val global: scala.tools.nsc.Global
    import global._

    /** The tree of an annotation; on Scala 2.11 this is `AnnotationInfo.tree`. */
    def annotationInfoTree(info: AnnotationInfo): Tree =
      info.tree

    /** Delegates directly to the compiler's `Symbol.isImplementationArtifact`. */
    def isImplementationArtifact(sym: Symbol): Boolean =
      sym.isImplementationArtifact
  }
}
15 |
--------------------------------------------------------------------------------
/src/main/scala-2.11/scala/tools/refactoring/implementations/oimports/ImplicitValDefTraverserPF.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations.oimports
3 |
4 | import scala.tools.nsc.Global
5 |
/**
 * Builds a partial function for tree traversers that handles `ValDef`s whose
 * right-hand side carries a `MacroExpansionAttachment`.
 *
 * For such a definition the traverser visits
 *  - the expression of the macro expansion, if the expansion is a `Typed` tree, and
 *  - the right-hand side itself (see `continueWithFunction`).
 */
class ImplicitValDefTraverserPF[G <: Global](val global: G) {
  import global._
  import global.analyzer._

  /**
   * Continues the traversal into the macro call on the right-hand side.
   *
   * For a `TypeApply` only the function part is visited; its type arguments
   * are deliberately skipped. A right-hand side of any other shape (e.g. an
   * `Apply`, `Select` or `Ident` macro call) is traversed as a whole instead of
   * failing with a `MatchError`.
   */
  private def continueWithFunction(traverser: Traverser, rhs: Tree): Unit = rhs match {
    case TypeApply(fun, _) =>
      traverser.traverse(fun)
    case other =>
      traverser.traverse(other)
  }

  /**
   * @param traverser the traverser used to visit the relevant subtrees
   * @return a partial function defined for `ValDef`s whose right-hand side
   *         has a `MacroExpansionAttachment`
   */
  def apply(traverser: Traverser): PartialFunction[Tree, Unit] = {
    case ValDef(_, _, _, rhs) if rhs.hasAttachment[MacroExpansionAttachment] =>
      val mea = rhs.attachments.get[MacroExpansionAttachment]
      mea.collect {
        case MacroExpansionAttachment(_, expanded: Typed) =>
          expanded
      }.foreach { expanded =>
        traverser.traverse(expanded.expr)
      }
      continueWithFunction(traverser, rhs)
  }
}
27 |
--------------------------------------------------------------------------------
/src/main/scala-2.12/README:
--------------------------------------------------------------------------------
1 | Put sources specifically for Scala-2.12.x here
--------------------------------------------------------------------------------
/src/main/scala-2.12/scala/tools/refactoring/ScalaVersionsAdapter.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 |
/**
 * Scala 2.12 specific shims over compiler APIs that differ between the
 * supported Scala versions.
 */
object ScalaVersionAdapters {
  trait CompilerApiAdapters {
    val global: scala.tools.nsc.Global
    import global._

    /** The tree of an annotation; on Scala 2.12 this is `AnnotationInfo.tree`. */
    def annotationInfoTree(info: AnnotationInfo): Tree =
      info.tree

    /** Delegates directly to the compiler's `Symbol.isImplementationArtifact`. */
    def isImplementationArtifact(sym: Symbol): Boolean =
      sym.isImplementationArtifact
  }
}
15 |
--------------------------------------------------------------------------------
/src/main/scala-2.12/scala/tools/refactoring/implementations/oimports/ImplicitValDefTraverserPF.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations.oimports
3 |
4 | import scala.tools.nsc.Global
5 |
/**
 * Builds a partial function for tree traversers that handles `ValDef`s whose
 * right-hand side carries a `MacroExpansionAttachment`.
 *
 * For such a definition the traverser visits
 *  - the expression of the macro expansion, if the expansion is a `Typed` tree, and
 *  - the right-hand side itself (see `continueWithFunction`).
 */
class ImplicitValDefTraverserPF[G <: Global](val global: G) {
  import global._
  import global.analyzer._

  /**
   * Continues the traversal into the macro call on the right-hand side.
   *
   * For a `TypeApply` only the function part is visited; its type arguments
   * are deliberately skipped. A right-hand side of any other shape (e.g. an
   * `Apply`, `Select` or `Ident` macro call) is traversed as a whole instead of
   * failing with a `MatchError`.
   */
  private def continueWithFunction(traverser: Traverser, rhs: Tree): Unit = rhs match {
    case TypeApply(fun, _) =>
      traverser.traverse(fun)
    case other =>
      traverser.traverse(other)
  }

  /**
   * @param traverser the traverser used to visit the relevant subtrees
   * @return a partial function defined for `ValDef`s whose right-hand side
   *         has a `MacroExpansionAttachment`
   */
  def apply(traverser: Traverser): PartialFunction[Tree, Unit] = {
    case ValDef(_, _, _, rhs) if rhs.hasAttachment[MacroExpansionAttachment] =>
      val mea = rhs.attachments.get[MacroExpansionAttachment]
      mea.collect {
        case MacroExpansionAttachment(_, expanded: Typed) =>
          expanded
      }.foreach { expanded =>
        traverser.traverse(expanded.expr)
      }
      continueWithFunction(traverser, rhs)
  }
}
27 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/MultiStageRefactoring.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 |
7 | import common.Change
8 |
/**
 * The super class of all refactoring implementations. A refactoring runs in
 * two phases: `prepare` checks the user's selection, and `perform` computes
 * the resulting changes from the preparation result and the parameters the
 * user supplied.
 */
abstract class MultiStageRefactoring extends Refactoring {

  this: common.CompilerAccess =>

  /** What a successful `prepare` hands on to `perform`. */
  type PreparationResult

  /** Describes why a refactoring cannot be performed. */
  case class PreparationError(cause: String)

  /**
   * Prepares the refactoring for the given selection.
   *
   * @return either the preparation result or a [[PreparationError]] that
   *         describes why the refactoring cannot be performed
   */
  def prepare(s: Selection): Either[PreparationError, PreparationResult]

  /**
   * The parameters chosen by the user. Together with the preparation result
   * they are passed to `perform`, which keeps refactorings stateless.
   */
  type RefactoringParameters

  /** Describes why `perform` failed. */
  case class RefactoringError(cause: String)

  /**
   * Performs the refactoring.
   *
   * @return either a [[RefactoringError]] or the list of changes
   */
  def perform(selection: Selection, prepared: PreparationResult, params: RefactoringParameters): Either[RefactoringError, List[Change]]

}
46 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/ParameterlessRefactoring.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 |
3 | import common.Change
4 |
/**
 * Helper mix-in for refactorings that take no user parameters.
 *
 * It fixes `RefactoringParameters` to an empty placeholder class and lets the
 * refactoring implement the simpler two-argument `perform` instead.
 */
trait ParameterlessRefactoring {

  this: MultiStageRefactoring =>

  /** Placeholder parameter type; instances carry no information. */
  class RefactoringParameters

  /** Ignores `params` and forwards to the two-argument `perform`. */
  def perform(selection: Selection, prepared: PreparationResult, params: RefactoringParameters): Either[RefactoringError, List[Change]] =
    perform(selection, prepared)

  /** The simplified `perform` that parameterless refactorings implement. */
  def perform(selection: Selection, prepared: PreparationResult): Either[RefactoringError, List[Change]]
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/Refactoring.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 |
7 | import scala.tools.nsc.io.AbstractFile
8 | import scala.tools.refactoring.common.Selections
9 | import scala.tools.refactoring.common.EnrichedTrees
10 | import scala.tools.refactoring.sourcegen.SourceGenerator
11 | import scala.tools.refactoring.transformation.TreeTransformations
12 | import scala.tools.refactoring.common.TextChange
13 | import scala.tools.refactoring.common.TracingImpl
14 |
/**
 * The Refactoring trait combines the transformation and source generation traits with
 * their dependencies. Refactoring is mixed in by all concrete refactorings and can be
 * used by users of the library.
 */
trait Refactoring extends Selections with TreeTransformations with TracingImpl with SourceGenerator with EnrichedTrees {

  this: common.CompilerAccess =>

  /**
   * Creates a list of changes from a list of (potentially changed) trees.
   *
   * @param changed the trees that are to be searched for modifications
   * @return the changes that can be applied to the source file, each one
   *         shrunk to the span in which it actually differs from the source
   */
  def refactor(changed: List[global.Tree]): List[TextChange] = context("main") {
    val changes = createChanges(changed)
    changes map minimizeChange
  }

  /**
   * Creates changes by applying a transformation to the root tree of an
   * abstract file. If the transformation yields no tree, `refactor` is called
   * with an empty list.
   */
  def transformFile(file: AbstractFile, transformation: Transformation[global.Tree, global.Tree]): List[TextChange] = {
    refactor(transformation(abstractFileToTree(file)).toList)
  }

  /**
   * Creates changes by applying several transformations to the root tree
   * of an abstract file.
   * Each transformation creates a new root tree that is used as input of
   * the next transformation. As soon as one transformation yields no tree,
   * the remaining ones are not applied and `refactor` is called with an
   * empty list.
   */
  def transformFile(file: AbstractFile, transformations: List[Transformation[global.Tree, global.Tree]]): List[TextChange] = {
    @scala.annotation.tailrec
    def inner(root: global.Tree, ts: List[Transformation[global.Tree, global.Tree]]): Option[global.Tree] = {
      ts match {
        case t :: rest =>
          t(root) match {
            case Some(newRoot) => inner(newRoot, rest)
            case None => None
          }
        case Nil => Some(root)
      }
    }

    refactor(inner(abstractFileToTree(file), transformations).toList)
  }

  /**
   * Makes a generated change as small as possible by eliminating the
   * common pre- and suffix between the change and the source file.
   */
  private def minimizeChange(change: TextChange): TextChange = change match {
    case TextChange(file, from, to, changeText) =>

      def commonPrefixLength(s1: Seq[Char], s2: Seq[Char]) =
        (s1 zip s2 takeWhile Function.tupled(_ == _)).length

      val original = file.content.subSequence(from, to).toString
      val replacement = changeText

      val commonStart = commonPrefixLength(original, replacement)
      // The common suffix is only searched after the common prefix, so prefix and
      // suffix never overlap and `from + commonStart <= to - commonEnd` holds.
      val commonEnd = commonPrefixLength(original.substring(commonStart).reverse, replacement.substring(commonStart).reverse)

      val minimizedChangeText = changeText.subSequence(commonStart, changeText.length - commonEnd).toString
      TextChange(file, from + commonStart, to - commonEnd, minimizedChangeText)
  }
}
84 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/analysis/CompilationUnitIndexes.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package analysis
7 |
8 | import collection.mutable.HashMap
9 | import collection.mutable.ListBuffer
10 |
/**
 * A CompilationUnitIndex is a light-weight index that
 * holds all definitions and references in a compilation
 * unit. This index is built with the companion object,
 * which traverses the whole compilation unit once and
 * then memoizes all relations.
 */
trait CompilationUnitIndexes {

  this: common.EnrichedTrees with common.CompilerAccess with common.TreeTraverser =>

  import global._

  /** Definitions and references found in the tree `root`, keyed by symbol. */
  trait CompilationUnitIndex {
    def root: Tree
    def definitions: Map[Symbol, List[DefTree]]
    def references: Map[Symbol, List[Tree]]
  }

  object CompilationUnitIndex {

    /**
     * Whether the Scala version reported by `scala.util.Properties` is older
     * than 2.10.1. For such versions, definitions of lazy values are indexed
     * under their lazy accessor.
     *
     * A version string that cannot be parsed (e.g. "version (unknown)") is
     * treated as a current version instead of failing with a `MatchError`.
     */
    private lazy val isLowerScalaVersionThan2_10_1: Boolean = {
      import scala.math.Ordering.Implicits._
      val Version = "version (\\d+)\\.(\\d+)\\.(\\d+).*".r
      scala.util.Properties.versionString match {
        case Version(major, minor, patch) =>
          (major.toInt, minor.toInt, patch.toInt) < ((2, 10, 1))
        case _ =>
          false
      }
    }

    /**
     * Traverses `tree` once and builds an index of all definitions and
     * references it contains. Asserts that it runs on the presentation
     * compiler thread.
     */
    def apply(tree: Tree): CompilationUnitIndex = {

      assertCurrentThreadIsPresentationCompiler()

      val defs = new HashMap[Symbol, ListBuffer[DefTree]]
      val refs = new HashMap[Symbol, ListBuffer[Tree]]

      def addDefinition(s: Symbol, t: DefTree): Unit = {
        def add(s: Symbol) =
          defs.getOrElseUpdate(s, new ListBuffer[DefTree]) += t

        t.symbol match {
          case ts: TermSymbol if ts.isLazy && isLowerScalaVersionThan2_10_1 =>
            add(ts.lazyAccessor)
          case _ =>
            add(s)
        }
      }

      def addReference(s: Symbol, t: Tree): Unit = {
        def add(s: Symbol) =
          refs.getOrElseUpdate(s, new ListBuffer[Tree]) += t

        add(s)

        s match {
          case _: ClassSymbol => ()
          /*
           * If we only have a TypeSymbol, we check if it is
           * a reference to another symbol and add this to the
           * index as well.
           *
           * This is needed for example to find the TypeTree
           * of a DefDef parameter-ValDef
           * */
          case ts: TypeSymbol =>
            ts.info match {
              case tr: TypeRef if tr.sym != null && /*otherwise we get wrong matches because of Type-Aliases*/
                tr.sym.nameString == s.nameString =>
                add(tr.sym)
              case _ => ()
            }
          case _ => ()
        }
      }

      def handleSymbol(s: Symbol, t: Tree) = t match {
        case t: DefTree => addDefinition(s, t)
        case _ => addReference(s, t)
      }

      (new TreeWithSymbolTraverser(handleSymbol)).traverse(tree)

      new CompilationUnitIndex {
        val root = tree
        val definitions = defs.map { case (sym, v) => sym.initialize -> v.toList }.toMap
        val references = refs.map { case (sym, v) => sym.initialize -> v.toList }.toMap
      }
    }
  }
}
105 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/analysis/ImportsToolbox.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package analysis
3 |
/**
 * Wraps the path-dependent types of a `CompilationUnitDependencies` instance so
 * that import lookups can be performed against its `global`.
 */
class ImportsToolbox[C <: CompilationUnitDependencies with common.EnrichedTrees](val cuDependenciesInstance: C) {
  import cuDependenciesInstance.global._

  def apply(tree: Tree) = new IsSelectNotInRelativeImports(tree)

  /** Checks if for given `Select` potentially done from `TypeTree` exists import (represented by `Import`)
   * in this `Select` scope or its parent.
   */
  class IsSelectNotInRelativeImports(wholeTree: Tree) {

    /**
     * Owners (as seen by the traverser's `currentOwner`) on the path to the
     * innermost range-positioned tree that encloses `of`.
     */
    private def collectPotentialOwners(of: Select): List[Symbol] = {
      var owners = List.empty[Symbol]
      def isSelectEmbracedByTree(tree: Tree): Boolean =
        tree.pos.isRange && tree.pos.start <= of.pos.start && of.pos.end <= tree.pos.end
      val collectPotentialOwners = new Traverser {
        var owns = List.empty[Symbol]
        override def traverse(t: Tree) = {
          owns = currentOwner :: owns
          t match {
            case potential if isSelectEmbracedByTree(potential) =>
              // NOTE(review): unlike the branch below, the owner pushed for an
              // embracing tree is never popped again. Embracing trees normally
              // form a chain of ancestors, so the extra entries are already on
              // the path. Confirm before making push/pop symmetric.
              owners = owns.distinct
              super.traverse(t)
            case t =>
              super.traverse(t)
              owns = owns.tail
          }
        }
      }
      collectPotentialOwners.traverse(wholeTree)
      owners
    }

    /** Returns `true` if no import has been found for tested `Select`, and `false` if one has.
     *
     * Note: the examples below assume that `b` in `val baz: b` produces `TypeTree` which is
     * converted to `Select`. So examples are just a visualization of potential use case.
     *
     * Examples:
     * {{{
     * trait A {
     *   import a.b
     *   def foo = {
     *     val baz: b = ???
     *   }
     * }
     * }}}
     * finds an import (returns `false`)
     * {{{
     * def foo = {
     *   import a.b
     *   val baz: b = ???
     * }
     * }}}
     * finds an import (returns `false`)
     * but
     * {{{
     * def foo = {
     *   val baz: b = ???
     *   import a.b
     * }
     * }}}
     * finds no import (returns `true`), because the import comes after the use.
     * For more see tests suites.
     */
    def apply(tested: Select): Boolean = {
      val doesNameFitInTested = compareNameWith(tested) _
      val nonPackageOwners = collectPotentialOwners(tested).filterNot { _.hasPackageFlag }
      def isValidPosition(t: Import): Boolean = t.pos.isRange && t.pos.start < tested.pos.start
      val isImportForTested = new Traverser {
        var found = false
        override def traverse(t: Tree) = t match {
          case imp: Import if isValidPosition(imp) && doesNameFitInTested(imp) && nonPackageOwners.contains(currentOwner) =>
            found = true
          case t => super.traverse(t)
        }
      }
      isImportForTested.traverse(wholeTree)
      !isImportForTested.found
    }

    /**
     * True if `that` imports the qualified name of `tested`, or a prefix of it.
     * A prefix only counts when it ends at a `.` boundary.
     */
    private def compareNameWith(tested: Select)(that: Import): Boolean = {
      import cuDependenciesInstance.additionalTreeMethodsForPositions
      def mkName(t: Tree) = if (t.symbol != null && t.symbol != NoSymbol) t.symbol.fullNameString else t.nameString
      val Select(testedQual, testedName) = tested
      val testedQName = List(mkName(testedQual), testedName).mkString(".")
      val Import(thatQual, thatSels) = that
      val impNames = thatSels.flatMap { sel =>
        if (sel.name == nme.WILDCARD) List(mkName(thatQual))
        else Set(sel.name, sel.rename).map { name => List(mkName(thatQual), name).mkString(".") }.toList
      }
      // `import a.b` covers `a.b` and `a.b.C`, but must not cover `a.bar`:
      // a bare `startsWith` would accept the latter.
      impNames.exists { imported =>
        testedQName == imported || testedQName.startsWith(imported + ".")
      }
    }
  }
}
98 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/analysis/NameValidation.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package analysis
7 |
8 | import tools.nsc.ast.parser.Scanners
9 | import tools.nsc.ast.parser.Tokens
10 | import scala.util.control.NonFatal
11 | import scala.reflect.internal.util.BatchSourceFile
12 |
/**
 * Validation helpers for new names: whether a string is a legal
 * identifier, and which existing symbols a new name would clash with.
 */
trait NameValidation {

  self: Indexes with common.Selections with common.CompilerAccess =>

  import global._

  /**
   * Returns true if `name` scans as exactly one identifier token followed by
   * end of input, as accepted by the Scala compiler's scanner. A non-fatal
   * exception during scanning yields `false`.
   */
  def isValidIdentifier(name: String): Boolean = {

    val scanner = new { val global = self.global } with Scanners {
      val cu = new global.CompilationUnit(new BatchSourceFile("", name))
      val scanner = new UnitScanner(cu)
    }.scanner

    try {
      scanner.init()
      val startsWithIdentifier = Tokens.isIdentifier(scanner.token)
      scanner.nextToken()
      startsWithIdentifier && scanner.token == Tokens.EOF
    } catch {
      case NonFatal(_) => false
    }
  }

  /**
   * Returns all symbols that might collide with the new name
   * at the given symbol's location.
   *
   * - private or local symbols: symbols named `name` in the selection of the
   *   owner's declaration;
   * - members of a regular (non-module, non-local) class: members named `name`
   *   anywhere in the owner's class hierarchy;
   * - everything else: members named `name` in the owner's package hierarchy.
   *
   * The implemented checks are only an approximation and not
   * necessarily correct.
   */
  def doesNameCollide(name: String, s: Symbol): List[Symbol] = {
    val owner = s.owner

    def hasNewName(candidate: Symbol) = candidate.nameString == name

    def usedInLocalScope: List[Symbol] =
      index.declaration(owner).map(TreeSelection).toList.flatMap(_.selectedSymbols.filter(hasNewName))

    def usedInClassHierarchy =
      index.completeClassHierarchy(owner).flatMap(_.tpe.members).filter(hasNewName)

    def usedInPackageHierarchy =
      index.completePackageHierarchy(owner).flatMap(_.tpe.members).filter(hasNewName)

    def ownerIsRegularClass =
      owner.isClass && !owner.isModuleClass && !nme.isLocalName(owner.name)

    val collisions =
      if (s.isPrivate || s.isLocal) usedInLocalScope
      else if (ownerIsRegularClass) usedInClassHierarchy
      else usedInPackageHierarchy

    collisions.distinct
  }
}
85 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/analysis/TreeAnalysis.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package analysis
7 |
8 | import common.Selections
9 |
/**
 * Provides some simple methods to analyze the program's
 * data flow, as used by Extract Method to find in and out
 * parameters.
 */
trait TreeAnalysis {

  this: Selections with Indexes with common.CompilerAccess =>

  /**
   * From the selection and in the scope of the currentOwner, returns
   * a list of all symbols that are owned by currentOwner and used inside
   * but declared outside the selection.
   *
   * Only symbols whose declaration the index knows and whose position is an
   * opaque range are returned, ordered by start offset.
   */
  @deprecated("use selection.inboundLocalDeps instead", "0.6")
  def inboundLocalDependencies(selection: Selection, currentOwner: global.Symbol): List[global.Symbol] = {

    val allLocalSymbols = selection.selectedSymbols filter {
      _.ownerChain.contains(currentOwner)
    }

    // Drops symbols declared inside the selection and symbols without a known
    // declaration (`forall` is true for `None`).
    allLocalSymbols
      .filterNot(s => index.declaration(s).forall(selection.contains))
      .filter(_.pos.isOpaqueRange)
      .sortBy(_.pos.start)
      .distinct
  }

  /**
   * From the selection and in the scope of the currentOwner, returns
   * a list of all symbols that are defined inside the selection and
   * used outside of it.
   */
  @deprecated("use selection.outboundLocalDeps instead", "0.6")
  def outboundLocalDependencies(selection: Selection): List[global.Symbol] = {

    val declarationsInTheSelection =
      selection.selectedSymbols.filter(s => index.declaration(s).exists(selection.contains))

    val occurrencesOfSelectedDeclarations = declarationsInTheSelection.flatMap(index.occurences)

    occurrencesOfSelectedDeclarations.filterNot(selection.contains).map(_.symbol).distinct
  }
}
51 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/Change.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package common
7 |
8 | import scala.tools.nsc.io.AbstractFile
9 | import scala.reflect.internal.util.SourceFile
10 |
/**
 * Common supertype of every change a refactoring can produce.
 */
sealed trait Change

/**
 * Replaces the characters `[from, to)` of `sourceFile` with `text`.
 */
case class TextChange(sourceFile: SourceFile, from: Int, to: Int, text: String) extends Change {

  /** The file backing `sourceFile`. */
  def file: AbstractFile = sourceFile.file

  /**
   * Instead of changing the existing file, returns a change that creates a new
   * file whose content is the original file with this change applied.
   *
   * @param fullNewName The fully qualified name of the new file, of the form
   *                    "some.package.FileName".
   */
  def toNewFile(fullNewName: String): NewFileChange = {
    val originalText = new String(sourceFile.content)
    val changedText = Change.applyChanges(this :: Nil, originalText)
    NewFileChange(fullNewName, changedText)
  }
}

/**
 * The changes creates a new source file, indicated by the `fullName` parameter. It is of
 * the form "some.package.FileName".
 */
case class NewFileChange(fullName: String, text: String) extends Change

/** A change concerning `sourceFile` whose target directory is `to` (interpreted by the consumer). */
case class MoveToDirChange(sourceFile: AbstractFile, to: String) extends Change

/** A change concerning `sourceFile` whose new name is `to` (interpreted by the consumer). */
case class RenameSourceFileChange(sourceFile: AbstractFile, to: String) extends Change
41 |
object Change {
  /**
   * Applies the list of changes to the source string. NewFileChanges are ignored.
   * Primarily used for testing / debugging.
   */
  def applyChanges(ch: List[Change], source: String): String = {
    val edits = ch.collect { case tc: TextChange => tc }.sortBy(descendingTo)

    // Overlapping edits are not necessarily an error, but Eclipse does not
    // accept them; asserting here catches them in our own tests. Only
    // neighbouring edits (in descending `to` order) are compared.
    edits.zip(edits.drop(1)).foreach { case (later, earlier) =>
      assert(later.from >= earlier.to)
    }

    // Applying from the end of the text backwards keeps earlier offsets valid.
    edits.foldLeft(source) { (text, edit) =>
      text.take(edit.from) + edit.text + text.drop(edit.to)
    }
  }

  /** Sort key ordering text changes by descending end offset. */
  private def descendingTo(change: TextChange) = -change.to

  /** Result of `discardOverlappingChanges`: the kept and the dropped changes. */
  case class AcceptReject(accepted: List[Change], rejected: List[Change])

  /**
   * Greedily splits the `TextChange`s of `changes` into non-overlapping
   * (accepted) and overlapping (rejected) ones. Changes are visited by
   * descending end offset; one is accepted if it ends at or before the start of
   * the most recently accepted change. Changes that are not `TextChange`s
   * appear in neither list.
   */
  def discardOverlappingChanges(changes: List[Change]): AcceptReject = {
    val candidates = changes.collect { case tc: TextChange => tc }.sortBy(descendingTo)
    candidates.foldLeft(AcceptReject(Nil, Nil)) { (result, change) =>
      val fitsBeforeLastAccepted = result.accepted match {
        case Nil => true
        case (last: TextChange) :: _ => last.from >= change.to
        case _ => false
      }
      if (fitsBeforeLastAccepted) result.copy(accepted = change :: result.accepted)
      else result.copy(rejected = change :: result.rejected)
    }
  }
}
84 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/CompilerAccess.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package common
7 |
8 | import tools.nsc.io.AbstractFile
9 |
/**
 * Access to a compiler instance and to the compilation units it knows about.
 */
trait CompilerAccess {

  /** Looks up the compilation unit of `f`; `None` if there is none. */
  def compilationUnitOfFile(f: AbstractFile): Option[global.CompilationUnit]

  /** The compiler whose trees, types and symbols this component works with. */
  val global: tools.nsc.Global
}
16 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.common
2 |
/*
 * FIXME: This class duplicates functionality from org.scalaide.core.compiler.CompilerApiExtensions.
 */
trait CompilerApiExtensions {
  this: CompilerAccess =>
  import global._

  /** Locate the smallest tree that encloses position.
   *
   *  @param tree The tree in which to search `pos`
   *  @param pos  The position to look for
   *  @param p    An additional condition to be satisfied by the resulting tree
   *  @return     The innermost enclosing tree for which p is true, or `EmptyTree`
   *              if the position could not be found.
   */
  def locateIn(tree: Tree, pos: Position, p: Tree => Boolean = t => true): Tree = {
    val locator = new FilteringLocator(pos, p)
    locator.locateIn(tree)
  }

  /** The innermost `PackageDef` in `tree` that encloses `pos`, or `EmptyTree`. */
  def enclosingPackage(tree: Tree, pos: Position): Tree =
    locateIn(tree, pos, {
      case _: PackageDef => true
      case _ => false
    })

  /** A `Locator` that additionally requires `p` to hold for a candidate tree. */
  private class FilteringLocator(pos: Position, p: Tree => Boolean) extends Locator(pos) {
    override def isEligible(t: Tree) = super.isEligible(t) && p(t)
  }

  /*
   * For Scala-2.10 (see scala.reflect.internal.Positions.Locator in Scala-2.11).
   *
   * Remembers the last eligible tree visited whose position includes `pos`.
   */
  private class Locator(pos: Position) extends Traverser {
    var last: Tree = _

    def locateIn(root: Tree): Tree = {
      last = EmptyTree
      traverse(root)
      last
    }

    protected def isEligible(t: Tree) = !t.pos.isTransparent

    override def traverse(t: Tree): Unit = t match {
      // Look through a TypeTree to the tree it was written as, if that lies inside it.
      case tt: TypeTree if tt.original != null && (tt.pos includes tt.original.pos) =>
        traverse(tt.original)
      case _ if t.pos includes pos =>
        if (isEligible(t)) last = t
        super.traverse(t)
      // Annotations may lie outside the definition's own position.
      case mdef: MemberDef =>
        traverseTrees(mdef.mods.annotations)
      case _ =>
    }
  }
}
57 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/InteractiveScalaCompiler.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package common
7 |
8 | import tools.nsc.io.AbstractFile
9 | import scala.reflect.internal.util.SourceFile
10 |
/**
 * Many parts of the library can work with the non-interactive global,
 * but some -- most notably the refactoring implementations -- need an
 * interactive compiler, which is expressed by this trait.
 */
trait InteractiveScalaCompiler extends CompilerAccess {

  val global: tools.nsc.interactive.Global

  /** Looks `f` up in the presentation compiler's `unitOfFile` map. */
  def compilationUnitOfFile(f: AbstractFile) = global.unitOfFile get f

  /**
   * Returns a fully loaded and typed Tree instance for the given SourceFile.
   *
   * The result is read from the request's `Response`: `Left(tree)` on success,
   * `Right(error)` on failure. (Presumably `Response.get` waits for the
   * presentation compiler to answer; confirm against its documentation.)
   */
  def askLoadedAndTypedTreeForFile(file: SourceFile): Either[global.Tree, Throwable] = {
    val response = new global.Response[global.Tree]
    global.askLoadedTyped(file, response)
    response.get
  }
}
31 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/Occurrences.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.common
2 |
3 | import scala.tools.refactoring.analysis.Indexes
4 | import scala.reflect.internal.util.RangePosition
5 | import scala.reflect.internal.util.OffsetPosition
6 |
/**
 * Provides functionality to get positions of term names. This includes the
 * term name's definition and all its uses.
 */
trait Occurrences extends Selections with CompilerAccess with Indexes {
  import global._

  /** An occurrence as `(offset, length)` in the source file. */
  type Occurrence = (Int, Int)

  /** The first `DefTree` in `root` (in `collect` order) whose decoded name equals `name`. */
  private def termNameDefinition(root: Tree, name: String) = {
    root.collect {
      case t: DefTree if t.name.decode == name =>
        t
    }.headOption
  }

  /**
   * Converts a position into an occurrence.
   * - A range position gives `(start, length)`.
   * - An offset position gives `(point, fallbackLength)`.
   * - Any other position (e.g. `NoPosition` on synthetic trees) gives `None`
   *   instead of a `MatchError`.
   */
  private def positionToOccurrence(pos: Position, fallbackLength: => Int): Option[Occurrence] = pos match {
    case p: RangePosition =>
      Some((p.start, p.end - p.start))
    case p: OffsetPosition =>
      Some((p.point, fallbackLength))
    case _ =>
      None
  }

  private def defToOccurrence(t: DefTree): Option[Occurrence] =
    positionToOccurrence(t.namePosition(), t.name.decode.length)

  private def refToOccurrence(t: RefTree): Option[Occurrence] =
    positionToOccurrence(t.pos, t.symbol.name.decode.length)

  /**
   * Returns the occurrence of the definition `t` (first, when its name has a
   * usable position) followed by the occurrences of all `RefTree` references
   * to its symbol. Trees without a usable position are skipped.
   */
  def allOccurrences(t: DefTree): List[Occurrence] = {
    val refOccurrences = index.references(t.symbol).collect {
      case ref: RefTree => refToOccurrence(ref)
    }.flatten
    defToOccurrence(t).toList ::: refOccurrences
  }

  /**
   * Searches for a definition of `name` in `root` and returns its position
   * and all positions of references to the definition.
   * Returns an empty list if no definition of `name` is found in `root`.
   */
  def termNameOccurrences(root: Tree, name: String): List[Occurrence] = {
    termNameDefinition(root, name) match {
      case Some(t) => allOccurrences(t)
      case None => Nil
    }
  }

  /**
   * Searches for a method definition of `defName` in `root` and returns for each
   * method parameter a list with the position of the parameter definition and
   * all occurrences.
   * Returns an empty list if `defName` is not found, defines something that is not
   * a method or if the method has no parameters.
   */
  def defDefParameterOccurrences(root: Tree, defName: String): List[List[Occurrence]] = {
    termNameDefinition(root, defName) match {
      case Some(DefDef(_, _, _, params, _, _)) =>
        params.flatten.map { p => allOccurrences(p) }
      case _ => Nil
    }
  }
}
74 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/PositionDebugging.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.common
2 |
3 | import scala.reflect.api.Position
4 | import scala.reflect.internal.util.NoPosition
5 | import scala.reflect.internal.util.NoSourceFile
6 | import scala.tools.refactoring.getSimpleClassName
7 |
/**
 * Helpers that render positions as human-readable strings, for debugging.
 */
object PositionDebugging {

  /** Describes `pos`, including a snippet of the source text it covers. */
  def format(pos: Position): String = formatInternal(pos, compact = false)

  /** Describes `pos` without the source snippet. */
  def formatCompact(pos: Position): String = formatInternal(pos, compact = true)

  /**
   * Renders the characters `[start, end)` of `source` between « and », with up
   * to 10 characters of context on each side. "\r\n" and "\n" are shown as
   * escape sequences, and the result is trimmed.
   */
  def format(start: Int, end: Int, source: Array[Char]): String = {
    val contextSize = 10

    def escapedSlice(from: Int, until: Int): String =
      source.view(from, until).mkString("").replace("\r\n", "\\r\\n").replace("\n", "\\n")

    val before = escapedSlice(start - contextSize, start)
    val marked = escapedSlice(start, end)
    val after = escapedSlice(end, end + contextSize)
    s"$before«$marked»$after".trim
  }

  private def formatInternal(pos: Position, compact: Boolean): String = {
    val undefined = pos == NoPosition || pos.source == NoSourceFile

    if (undefined) {
      "UndefinedPosition"
    } else {
      val posType = getSimpleClassName(pos)

      val (start, point, end) =
        if (pos.isRange) (pos.start, pos.point, pos.end)
        else (pos.point, pos.point, pos.point)

      val markerString =
        if (start != end) s"($start, $point, $end)"
        else s"($start)"

      val relevantSource =
        if (!compact) "[" + format(start, end, pos.source.content) + "]"
        else ""

      s"$posType$markerString$relevantSource"
    }
  }
}
57 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/TracingHelpers.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.common
2 |
object TracingHelpers {

  /**
   * Shortens `text` for trace output.
   *
   * - Only the first line is kept, cut to at most 50 characters.
   * - "..." marks any truncation.
   * - When more lines follow, a "(N more lines ommitted)" suffix reports how
   *   many were dropped.
   */
  def compactify(text: String): String = {
    // Not `text.lines`: on JDK 11+ that resolves to java.lang.String#lines
    // (a java.util.stream.Stream) and no longer compiles here. This is exactly
    // how StringOps#lines splits in Scala 2.10-2.12.
    val lines = text.linesWithSeparators.map(_.stripLineEnd).toArray
    val firstLine = lines.headOption.getOrElse("")
    val snipAfter = 50

    val (compactedFirstLine, dotsAdded) =
      if (firstLine.length <= snipAfter) (firstLine, false)
      else (firstLine.substring(0, snipAfter) + "...", true)

    if (lines.length <= 1) {
      compactedFirstLine
    } else {
      val compactedFirstLineWithDots =
        if (dotsAdded) compactedFirstLine
        else compactedFirstLine + "..."

      val moreLines = lines.length - 1
      compactedFirstLineWithDots + s"($moreLines more lines ommitted)"
    }
  }

  /** `compactify` applied to the string form of `any` ("null" for `null`). */
  def toCompactString(any: Any): String = compactify("" + any)
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/TreeExtractors.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package common
7 |
8 | import PartialFunction.cond
9 |
trait TreeExtractors {

  this: common.CompilerAccess =>

  import global._

  /** Names matched by the extractors below, created lazily in the compiler's name table. */
  object Names {
    lazy val scala = newTermName("scala")
    lazy val pkg = newTermName("package")
    lazy val None = newTermName("None")
    lazy val Some = newTermName("Some")
    lazy val Predef = newTermName("Predef")
    lazy val reflect = newTermName("reflect")
    lazy val apply = newTermName("apply")
    lazy val Nil = newTermName("Nil")
    lazy val immutable = newTypeName("immutable")
    lazy val :: = newTermName("$colon$colon")
    lazy val List = newTermName("List")
    lazy val Seq = newTermName("Seq")
    lazy val collection = newTermName("collection")
    lazy val immutableTerm = newTermName("immutable")
    lazy val scalaType = newTypeName("scala")
  }

  /**
   * Matches `scala.Some[T](x)` (with an explicit `TypeTree` argument) and extracts `x`.
   */
  object SomeExpr {
    def unapply(t: Tree): Option[Tree] = PartialFunction.condOpt(t) {
      case Apply(TypeApply(Select(Select(Ident(Names.scala), Names.Some), _), (_: TypeTree) :: Nil), value :: Nil) =>
        value
    }
  }

  /**
   * Matches a reference to `scala.None`.
   */
  object NoneExpr {
    def unapply(t: Tree): Boolean = cond(t) {
      case Select(Ident(Names.scala), Names.None) => true
    }
  }

  /**
   * Matches a single-element list built either as `x :: immutable.Nil` (via a
   * temporary val, as in the typed tree) or as `immutable.List.apply(x)`,
   * and extracts `x`.
   */
  object ListExpr {
    def unapply(t: Tree): Option[Tree] = PartialFunction.condOpt(t) {
      case Block(
        ValDef(_, tempName, _, element) :: Nil,
        Apply(TypeApply(Select(Select(This(Names.immutable), Names.Nil), Names.::), (_: TypeTree) :: Nil),
          Ident(usedName) :: Nil)) if tempName == usedName =>
        element
      case Apply(TypeApply(Select(Select(This(Names.immutable), Names.List), Names.apply), (_: TypeTree) :: Nil), element :: Nil) =>
        element
    }
  }

  /**
   * Matches a reference to `immutable.Nil`.
   */
  object NilExpr {
    def unapply(t: Tree): Boolean = cond(t) {
      case Select(This(Names.immutable), Names.Nil) => true
    }
  }

  /**
   * Matches the `()` literal.
   */
  object UnitLit {
    def unapply(t: Tree): Boolean = cond(t) {
      case Literal(constant) if constant.tag == UnitTag => true
    }
  }

  /**
   * Extracts the simple name of a tree's type, looking through constant types.
   */
  object HasType {

    /** Name of the type constructor of `t`, or `None` for non-`TypeRef` types. */
    def getTypeName(t: Type): Option[String] = t match {
      case TypeRef(_, sym, _) => Some(sym.nameString)
      case ConstantType(constant) => getTypeName(constant.tpe)
      case _ => None
    }

    def unapply(t: Tree): Option[String] = getTypeName(t.tpe)
  }

  /**
   * True if the tree's type is a `TypeRef` to `Unit`.
   */
  def hasUnitType(t: Tree): Boolean = cond(t.tpe) {
    case TypeRef(_, sym, _) => sym == definitions.UnitClass
  }
}
114 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/exceptions.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package common
7 |
/** Signals that no tree was found for `file`; `file` is named in the message. */
class TreeNotFound(file: String) extends Exception(s"Tree not found for file $file.")

/** An error raised by a refactoring; `cause` becomes the exception message. */
class RefactoringError(cause: String) extends Exception(cause)
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/package.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 |
package object common {

  /**
   * The tracing implementation selected for the library.
   *
   * [[SilentTracing]] turns every trace call into a no-op and is the choice
   * for production; point this alias at [[DebugTracing]] when debugging.
   */
  type TracingImpl = SilentTracing
}
11 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/common/tracing.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package common
7 |
8 | import java.io.File
9 | import java.io.FileOutputStream
10 | import java.io.PrintStream
11 |
/**
 * Interface of the library's tracing facility.
 */
trait Tracing {

  /** Emits the message `msg`. */
  def trace(msg: => String): Unit

  /**
   * Emits `msg` together with `arg1` and `args`
   * (`DebugTracing` formats them into `msg` with `String.format`).
   */
  def trace(msg: => String, arg1: => Any, args: Any*): Unit

  /** Evaluates `body` within a tracing context called `name` and returns its result. */
  def context[T](name: String)(body: => T): T

  /**
   * Result of `wrapInTraceAndReturn`: `value \\ f` returns `value`.
   * Implementations may first pass it to `f` (`DebugTracing` does, `SilentTracing` does not).
   */
  protected abstract class TraceAndReturn[T] {
    def \\(trace: T => Unit): T
  }

  protected implicit def wrapInTraceAndReturn[T](t: T): TraceAndReturn[T]
}
25 |
object DebugTracing {

  /**
   * Target of all debug trace output.
   *
   * If the system property `scala.refactoring.traceFile` names a file, output
   * is appended to it: UTF-8, auto-flushing. Otherwise, or if the file cannot
   * be opened, output goes to `System.out`.
   */
  private val debugStream: PrintStream =
    Option(System.getProperty("scala.refactoring.traceFile")) match {
      case Some(fileName) =>
        try {
          val append = true
          new PrintStream(new FileOutputStream(new File(fileName), append), true, "UTF-8")
        } catch {
          case e: Exception =>
            e.printStackTrace()
            System.err.println(s"Could not open '$fileName' for writing; falling back to System.out...")
            System.out
        }
      case None =>
        System.out
    }

  /** Writes `str` as one line to `debugStream`. */
  private def printLine(str: String): Unit = debugStream.println(str)
}
46 |
/**
 * Traces to STDOUT or a custom file (via the system property `scala.refactoring.traceFile`).
 *
 * Nested `context` calls are drawn as an indented tree using box-drawing characters.
 */
trait DebugTracing extends Tracing {
  import DebugTracing._

  var level = 0
  val marker = "│"
  val indent = " "

  override def context[T](name: String)(body: => T): T = {

    val spacer = "─" * (indent.length - 1)

    printLine((indent * level) + "╰" + spacer + "┬────────")
    level += 1
    trace("→ " + name)

    // Restore the nesting level even if `body` throws; otherwise every later
    // trace line would stay indented one level too deep.
    try body
    finally {
      level -= 1
      printLine((indent * level) + "╭" + spacer + "┴────────")
    }
  }

  override def trace(msg: => String, arg1: => Any, args: Any*): Unit = {

    val as: Array[AnyRef] = (arg1 +: args.toArray) map {
      case s: String => "«" + s.replaceAll("\n", "\\\\n") + "»"
      // Previously a MatchError; String.format renders it as "null".
      case null => null
      case a: AnyRef => a
    }

    trace(msg.format(as: _*))
  }

  override def trace(msg: => String): Unit = {
    val border = (indent * level) + marker
    // Literal replacement: `indent` and `marker` may be overridden with
    // characters that are special in regex replacement strings (`$`, `\`).
    printLine(border + msg.replace("\n", "\n" + border))
  }

  protected implicit final def wrapInTraceAndReturn[T](t: T): TraceAndReturn[T] = new TraceAndReturn[T] {
    def \\(trace: T => Unit) = {
      trace(t)
      t
    }
  }
}
93 |
/**
 * A `Tracing` that does nothing: messages (by-name) are never evaluated and
 * contexts simply evaluate their body.
 */
trait SilentTracing extends Tracing {

  @inline
  def context[T](name: String)(body: => T): T = body

  @inline
  def trace(msg: => String): Unit = ()

  @inline
  def trace(msg: => String, arg1: => Any, args: Any*): Unit = ()

  protected implicit final def wrapInTraceAndReturn[T](t: T): TraceAndReturn[T] = new TraceAndReturn[T] {
    def \\(ignored: T => Unit): T = t
  }
}
108 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/AddField.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
4 | import scala.tools.nsc.io.AbstractFile
5 | import scala.tools.refactoring.common.TextChange
6 | import scala.tools.nsc.ast.parser.Tokens
7 |
abstract class AddField extends AddValOrDef {

  val global: tools.nsc.interactive.Global
  import global._

  /**
   * Adds a `val` (or a `var` if `isVar`) named `valName`, initialized with `???`,
   * to the class or object `className` in `file`.
   *
   * @param returnTypeOpt the declared type of the field; `None` leaves the type tree empty
   */
  def addField(file: AbstractFile, className: String, valName: String, isVar: Boolean, returnTypeOpt: Option[String], target: AddMethodTarget): List[TextChange] =
    addValOrDef(file, className, target, addField(valName, isVar, returnTypeOpt, _))

  /** Builds the new field and inserts it into `classOrObjectDef`. */
  private def addField(valName: String, isVar: Boolean, returnTypeOpt: Option[String], classOrObjectDef: Tree): List[TextChange] = {
    val placeholderRhs = Ident("???")

    val declaredType = returnTypeOpt match {
      case Some(typeName) => TypeTree(newType(typeName))
      case None => new TypeTree
    }

    val modifiers =
      if (!isVar) NoMods
      else Modifiers(Flag.MUTABLE).withPosition(Tokens.VAR, NoPosition)

    val field = mkValOrVarDef(modifiers, valName, placeholderRhs, declaredType)

    val insertion = insertDef(field)
    val inserted = insertion(classOrObjectDef)
    refactor(inserted.toList)
  }
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/AddImportStatement.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package implementations
7 |
8 | import common.InteractiveScalaCompiler
9 | import common.Change
10 | import scala.tools.nsc.io.AbstractFile
11 | import scala.tools.refactoring.common.TextChange
12 |
abstract class AddImportStatement extends Refactoring with InteractiveScalaCompiler {

  override val global: tools.nsc.interactive.Global

  /** Adds an import of the fully qualified name `fqName` to `file`. */
  def addImport(file: AbstractFile, fqName: String): List[TextChange] =
    addImports(file, fqName :: Nil)

  /** Adds imports for all fully qualified names in `importsToAdd` to `file`. */
  def addImports(file: AbstractFile, importsToAdd: Iterable[String]): List[TextChange] = {
    val root = abstractFileToTree(file)
    val changes = addImportTransformation(importsToAdd.toSeq)(root)
    changes.toList
  }

  @deprecated("Use addImport(file, ..) instead", "0.4.0")
  def addImport(selection: Selection, fullyQualifiedName: String): List[Change] =
    addImport(selection.file, fullyQualifiedName)

  @deprecated("Use addImport(file, ..) instead", "0.4.0")
  def addImport(selection: Selection, pkg: String, name: String): List[Change] =
    addImport(selection.file, s"$pkg.$name")

  /** Always throws `UnsupportedOperationException`. */
  @deprecated("Not needed anymore, don't override.", "0.4.0")
  def getContentForFile(file: AbstractFile): Array[Char] = throw new UnsupportedOperationException
}
39 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/AddMethod.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
4 | import scala.tools.nsc.io.AbstractFile
5 | import scala.tools.refactoring.common.TextChange
6 |
/**
 * Refactoring that adds a method stub, with `???` as its body, to a class or object.
 */
abstract class AddMethod extends AddValOrDef {

  val global: tools.nsc.interactive.Global
  import global._

  /**
   * Adds the method `methodName` to the class or object `className` in `file`.
   *
   * @param parameters     parameter lists, each a list of (parameter name, type name) pairs
   * @param typeParameters type parameter names; empty for a non-generic method
   * @param returnType     optional return type annotation, rendered verbatim
   * @param target         how the class or object is located in the file
   */
  def addMethod(file: AbstractFile, className: String, methodName: String, parameters: List[List[(String, String)]], typeParameters: List[String], returnType: Option[String], target: AddMethodTarget): List[TextChange] =
    addValOrDef(file, className, target, (tree: Tree) => addMethod(methodName, parameters, typeParameters, returnType, tree))

  /** Same as the overload above, for a method without type parameters. */
  def addMethod(file: AbstractFile, className: String, methodName: String, parameters: List[List[(String, String)]], returnType: Option[String], target: AddMethodTarget): List[TextChange] =
    addMethod(file, className, methodName, parameters, Nil, returnType, target)

  /** Builds the stub method and appends it to `classOrObjectDef`. */
  private def addMethod(methodName: String, parameters: List[List[(String, String)]], typeParameters: List[String], returnTypeOpt: Option[String], classOrObjectDef: Tree): List[TextChange] = {
    val paramSymbolss = parameters.map(_.map { case (paramName, typeName) =>
      NoSymbol.newValue(newTermName(paramName)).setInfo(newType(typeName))
    })

    // All type parameter names are packed into one TypeDef named by the
    // comma-separated list (presumably so they print as a single `[A, B]`).
    val typeParams =
      if (typeParameters.isEmpty) Nil
      else new TypeDef(NoMods, newTypeName(typeParameters.mkString(", ")), Nil, EmptyTree) :: Nil

    val body = Ident("???") :: Nil

    val newDef = mkDefDef(NoMods, methodName, paramSymbolss, body, typeParams,
      returnTypeOpt = returnTypeOpt.map(typeName => TypeTree(newType(typeName))))

    refactor(insertDef(newDef).apply(classOrObjectDef).toList)
  }
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/AddValOrDef.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
4 | import scala.tools.nsc.io.AbstractFile
5 | import scala.tools.refactoring.Refactoring
6 | import scala.tools.refactoring.common.TextChange
7 |
8 | import common.InteractiveScalaCompiler
9 |
/**
 * Shared logic for refactorings that add a `val`, `var` or `def` to a class or
 * object that is located by name inside a source file.
 */
trait AddValOrDef extends Refactoring with InteractiveScalaCompiler {

  val global: tools.nsc.interactive.Global
  import global._

  /**
   * Locates the class or object named `className` in `file` and passes its
   * definition to `changeFunc`.
   *
   * @param file       the source file to search
   * @param className  the decoded name of the class or object
   * @param target     `AddToClass` / `AddToObject` select the first matching class /
   *                   object found by `find`; `AddToClosest(offset)` selects the class
   *                   or object of that name whose name position is the greatest
   *                   one strictly smaller than `offset`
   * @param changeFunc builds the text changes for the located definition
   * @return the changes produced by `changeFunc`
   * @throws NoSuchElementException if no matching class or object exists
   */
  def addValOrDef(file: AbstractFile, className: String, target: AddMethodTarget, changeFunc: Tree => List[TextChange]): List[TextChange] = {
    val astRoot = abstractFileToTree(file)

    // Definitions are matched by name: passing in the symbol would be nicer,
    // but it might not be available.
    val classOrObjectDef: Option[Tree] = target match {
      case AddToClosest(offset) =>
        val candidates = astRoot.collect {
          case classDef: ClassDef if classDef.name.decode == className =>
            classDef -> classDef.namePosition.point
          case moduleDef: ModuleDef if moduleDef.name.decode == className =>
            moduleDef -> moduleDef.namePosition.point
        }
        // the class/object definition just before the given offset
        candidates.filter(_._2 < offset).sortBy(_._2).lastOption.map(_._1)
      case _ =>
        astRoot.find {
          case classDef: ClassDef if target == AddToClass => classDef.name.decode == className
          case moduleDef: ModuleDef if target == AddToObject => moduleDef.name.decode == className
          case _ => false
        }
    }

    // Same exception type as the former `Option.get`, but with a useful message.
    val definition = classOrObjectDef.getOrElse {
      throw new NoSuchElementException(s"No class or object named '$className' found for target $target")
    }
    changeFunc(definition)
  }

  /**
   * Transformation that, applied to an `ImplDef`, yields its template with
   * `valOrDef` appended at the end of the body.
   */
  protected def insertDef(valOrDef: ValOrDefDef) = {
    def appendToTemplate(tpl: Template) = tpl copy (body = tpl.body ::: valOrDef :: Nil) replaces tpl

    transform {
      case implDef: ImplDef => appendToTemplate(implDef.impl)
    }
  }

  /** A placeholder `Type` whose string form is exactly `name`. */
  protected def newType(name: String) = new Type {
    override def safeToString: String = name
  }
}

/** Tells `AddValOrDef.addValOrDef` which definition with the given name to extend. */
sealed trait AddMethodTarget
/** The first class with the given name. */
case object AddToClass extends AddMethodTarget
/** The first object with the given name. */
case object AddToObject extends AddMethodTarget
/** The class or object with the given name declared closest before `offset`. */
case class AddToClosest(offset: Int) extends AddMethodTarget
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/ChangeParamOrder.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
4 |
/**
 * Refactoring that reorders the parameters within each parameter list of a method.
 */
abstract class ChangeParamOrder extends MethodSignatureRefactoring {

  import global._

  /** For one parameter list: position `i` receives the original parameter at index `perm(i)`. */
  type Permutation = List[Int]

  /**
   * There has to be a permutation for each parameter list of the selected method.
   */
  type RefactoringParameters = List[Permutation]

  /** Valid when every parameter list has a permutation that uses each of its indexes exactly once. */
  override def checkRefactoringParams(prep: PreparationResult, affectedDefs: AffectedDefs, params: RefactoringParameters) = {
    def isPermutationOf(vparams: List[ValDef], perm: Permutation) =
      perm.sorted == List.range(0, vparams.length)

    (prep.defdef.vparamss corresponds params)(isPermutationOf)
  }

  /** Applies `permutations(i)` to the i-th list; unmatched trailing lists on either side are dropped. */
  def reorder[T](origVparamss: List[List[T]], permutations: List[Permutation]): List[List[T]] =
    origVparamss.zip(permutations).map { case (params, perm) => reorderSingleParamList(params, perm) }

  /** Picks the elements of `origVparams` in the order given by `permutation`. */
  def reorderSingleParamList[T](origVparams: List[T], permutation: Permutation) =
    permutation.map(index => origVparams(index))

  override def defdefRefactoring(parameters: RefactoringParameters) = transform {
    case orig @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
      DefDef(mods, name, tparams, reorder(vparamss, parameters), tpt, rhs) replaces orig
  }

  override def applyRefactoring(params: RefactoringParameters) = transform {
    case orig @ Apply(fun, args) =>
      // The `- 1` turns `paramListPos` into an index into `params`
      // (it appears to count parameter lists from 1).
      val listIndex = paramListPos(findOriginalTree(orig)) - 1
      Apply(fun, reorderSingleParamList(args, params(listIndex))) replaces orig
  }

  /**
   * Adapts the permutations to `toRefactor`: lists it does not have are dropped,
   * and lists it cannot touch get the identity permutation.
   */
  override def prepareParamsForSingleRefactoring(originalParams: RefactoringParameters, selectedMethod: DefDef, toRefactor: DefInfo): RefactoringParameters = {
    val toDrop = originalParams.size - toRefactor.nrParamLists
    val untouchableCount = toRefactor.nrUntouchableParamLists
    val identities = originalParams
      .drop(toDrop - untouchableCount)
      .take(untouchableCount)
      .map(perm => List.range(0, perm.size))
    identities ::: originalParams.drop(toDrop)
  }

}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/ClassParameterDrivenSourceGeneration.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
4 | import scala.tools.refactoring.MultiStageRefactoring
5 |
6 | import common.Change
7 |
/**
 * Base class for refactorings that generate class-level source code from the
 * parameters of the selected class.
 */
abstract class ClassParameterDrivenSourceGeneration extends MultiStageRefactoring with common.InteractiveScalaCompiler {

  import global._

  /**
   * Result of a successful [[prepare]].
   *
   * @param classDef                the selected class
   * @param classParams             the class's `nonPrivateClassParameters`; the meaning of
   *                                the paired Boolean is defined by that helper
   * @param existingEqualityMethods the class's `existingEqualityMethods`
   */
  case class PreparationResult(
    classDef: ClassDef,
    classParams: List[(ValDef, Boolean)],
    existingEqualityMethods: List[ValOrDefDef])

  /**
   * Options chosen by the user.
   *
   * @param callSuper                   whether calls to super should be generated
   * @param paramsFilter                decides for each class parameter (`ValDef`) whether
   *                                    it is used in the generated code, e.g. in
   *                                    equals/hashCode computations
   * @param keepExistingEqualityMethods whether existing equality methods (equals,
   *                                    canEqual and hashCode) are kept (`true`) or replaced
   */
  case class RefactoringParameters(
    callSuper: Boolean = true,
    paramsFilter: ValDef => Boolean,
    keepExistingEqualityMethods: Boolean)

  /** Accepts a selected class that is not a trait and that [[failsBecause]] does not reject. */
  def prepare(s: Selection): Either[PreparationError, PreparationResult] = {
    val notAClass = Left(PreparationError("No class definition selected."))
    s.findSelectedOfType[ClassDef] match {
      case Some(classDef) if !(classDef.hasSymbol && classDef.symbol.isTrait) =>
        failsBecause(classDef) match {
          case Some(reason) =>
            Left(PreparationError(reason))
          case None =>
            val template = classDef.impl
            val classParams = template.nonPrivateClassParameters
            val equalityMethods = template.existingEqualityMethods
            Right(PreparationResult(classDef, classParams, equalityMethods))
        }
      case _ =>
        notAClass
    }
  }

  /** Returns the reason why `classDef` cannot be refactored, or `None` if it can. */
  def failsBecause(classDef: ClassDef): Option[String]

  /** Applies [[sourceGeneration]] to the template of the prepared class only. */
  def perform(selection: Selection, prep: PreparationResult, params: RefactoringParameters): Either[RefactoringError, List[Change]] = {
    val selectedParams = prep.classParams.map(_._1).filter(params.paramsFilter)

    val isPreparedTemplate = filter {
      case prep.classDef.impl => true
    }

    val generateInTemplate = topdown {
      matchingChildren {
        isPreparedTemplate &> sourceGeneration(selectedParams, prep, params)
      }
    }

    Right(transformFile(selection.file, generateInTemplate))
  }

  /** The transformation applied to the class template, given the selected class parameters. */
  def sourceGeneration(params: List[ValDef], preparationResult: PreparationResult, refactoringParams: RefactoringParameters): Transformation[Tree, Tree]

}
69 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/ExpandCaseClassBinding.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package implementations
7 |
8 | import common.Change
9 | import scala.tools.refactoring.analysis.GlobalIndexes
10 |
/**
 * Refactoring that expands a selected binding in a case clause, whose type is a
 * case class, into a constructor-like pattern listing the case accessors. The
 * binder name is kept (`name @ ...`) only if the clause body still references it.
 */
abstract class ExpandCaseClassBinding extends MultiStageRefactoring with ParameterlessRefactoring with GlobalIndexes {

  this: common.CompilerAccess =>

  import global._

  /**
   * @param bind the selected binding
   * @param sym  the case class that the bound pattern's type refers to
   * @param body the body of the enclosing case clause
   */
  case class PreparationResult(bind: Bind, sym: ClassSymbol, body: Tree)

  def prepare(s: Selection): Either[PreparationError, PreparationResult] = {

    def failure = Left(PreparationError("No binding to expand found. Please select a binding in a case clause."))

    val prepared = for {
      caseDef <- s.findSelectedOfType[CaseDef]
      bind <- s.findSelectedOfType[Bind]
    } yield {
      bind.body.tpe match {
        case TypeRef(_, sym: ClassSymbol, _) if sym.isCaseClass =>
          Right(PreparationResult(bind, sym, caseDef.body))
        case _ =>
          failure
      }
    }

    prepared getOrElse failure
  }

  def perform(selection: Selection, preparationResult: PreparationResult): Either[RefactoringError, List[Change]] = {

    val PreparationResult(bind, sym, body) = preparationResult

    val apply = {
      val argIdents = sym.info.decls.toList
        .filter(decl => decl.isCaseAccessor && decl.isMethod)
        .map(accessor => Ident(accessor.name))
      // don't insert TupleX in the code
      val fun = if (sym.nameString.matches("Tuple\\d+")) Ident("") else Ident(sym.name)
      Apply(fun, argIdents)
    }

    // A mini-index over the case body tells whether the binder is still referenced.
    val nameIsNotReferenced = GlobalIndex(body).references(bind.symbol).isEmpty

    // Unreferenced binders are dropped; otherwise a `name @ ` prefix is kept.
    val replacement = (if (nameIsNotReferenced) apply else bind copy (body = apply)) replaces bind

    Right(refactor(List(replacement)))
  }

  // NOTE(review): `perform` builds its own `GlobalIndex` for the case body; this
  // member is not used by the code shown here and fails if it is called.
  def index: IndexLookup = sys.error("")
}
74 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/ExplicitGettersSetters.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package implementations
7 |
8 | import common.Change
9 | import scala.tools.nsc.ast.parser.Tokens
10 | import scala.tools.nsc.symtab.Flags
11 |
/**
 * Refactoring that replaces a selected field with a backing field named `_name`,
 * an explicit getter `name` and, if the field is mutable, an explicit setter `name_=`.
 */
abstract class ExplicitGettersSetters extends MultiStageRefactoring with ParameterlessRefactoring with common.InteractiveScalaCompiler {

  import global._

  type PreparationResult = ValDef

  /** Succeeds with the selected `ValDef`, if there is one. */
  def prepare(s: Selection) = {
    s.findSelectedOfType[ValDef] match {
      case Some(valdef) => Right(valdef)
      case None => Left(new PreparationError("no valdef selected"))
    }
  }

  /**
   * Rewrites the template enclosing the selection.
   *
   * @return `Left` if the selection is not inside a template, otherwise the changes
   */
  override def perform(selection: Selection, selectedValue: PreparationResult): Either[RefactoringError, List[Change]] =
    selection.findSelectedOfType[Template] match {
      // A plain match replaces the former `getOrElse { return ... }`, which was a
      // non-local return out of a by-name argument (implemented via an exception).
      case None =>
        Left(RefactoringError("no template found"))

      case Some(template) =>
        val createSetter = selectedValue.symbol.isMutable

        val publicName = selectedValue.name.toString.trim

        val privateName = "_" + publicName

        // Only the mutable variant records PRIVATE and VAR modifier positions.
        val privateFieldMods =
          if (createSetter)
            Modifiers(Flags.PARAMACCESSOR).withPosition(Flags.PRIVATE, NoPosition).withPosition(Tokens.VAR, NoPosition)
          else
            Modifiers(Flags.PARAMACCESSOR)

        val privateField = selectedValue copy (mods = privateFieldMods, name = newTermName(privateName))

        // def name = { _name }
        val getter = DefDef(
          mods = Modifiers(Flags.METHOD) withPosition (Flags.METHOD, NoPosition),
          name = newTermName(publicName),
          tparams = Nil,
          vparamss = Nil,
          tpt = EmptyTree,
          rhs = Block(
            Ident(privateName) :: Nil, EmptyTree))

        // def name_=(name: T) = { _name = name }
        val setter = DefDef(
          mods = Modifiers(Flags.METHOD) withPosition (Flags.METHOD, NoPosition),
          name = newTermName(publicName + "_="),
          tparams = Nil,
          vparamss = List(List(ValDef(Modifiers(Flags.PARAM), newTermName(publicName), TypeTree(selectedValue.tpt.tpe), EmptyTree))),
          tpt = EmptyTree,
          rhs = Block(
            Assign(
              Ident(privateName),
              Ident(publicName)) :: Nil, EmptyTree))

        val insertGetterSettersTransformation = transform {
          case tpl: Template if tpl == template =>
            // Swap the selected field for the renamed backing field, keeping its position.
            val classParameters = tpl.body.map {
              case t: ValDef if t == selectedValue => privateField setPos t.pos
              case t => t
            }

            val body =
              if (createSetter) getter :: setter :: classParameters
              else getter :: classParameters

            tpl.copy(body = body) setPos tpl.pos
        }

        Right(transformFile(selection.file, topdown(matchingChildren(insertGetterSettersTransformation))))
    }
}
86 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/ExtractMethod.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package implementations
7 |
8 | import scala.tools.nsc.symtab.Flags
9 |
10 | import analysis.Indexes
11 | import analysis.TreeAnalysis
12 | import common.Change
13 | import common.InteractiveScalaCompiler
14 | import transformation.TreeFactory
15 |
/**
 * Refactoring that moves the selected expressions or statements of a method into
 * a new method (carrying a PRIVATE modifier position) and replaces the selection
 * with a call to it.
 *
 * Inbound local dependencies of the selection become the new method's parameters.
 * Outbound local dependencies are returned from it.
 */
abstract class ExtractMethod extends MultiStageRefactoring with TreeAnalysis with Indexes with TreeFactory with InteractiveScalaCompiler {

  val global: tools.nsc.interactive.Global
  import global._

  /** The method that encloses the selection. */
  type PreparationResult = Tree

  /** The name of the method to create. */
  type RefactoringParameters = String

  def prepare(s: Selection) = {
    val enclosingMethod = s.findSelectedOfType[DefDef]
    if (s.selectedTopLevelTrees.isEmpty)
      Left(PreparationError("No expressions or statements selected."))
    else
      enclosingMethod match {
        case Some(method) =>
          Right(method)
        case None =>
          Left(PreparationError("No enclosing method definition found: please select code that's inside a method."))
      }
  }

  def perform(selection: Selection, selectedMethod: PreparationResult, methodName: RefactoringParameters): Either[RefactoringError, List[Change]] = {

    val selectedTrees = selection.selectedTopLevelTrees

    val inbound = {
      val localDeps = inboundLocalDependencies(selection, selectedMethod.symbol)
      selectedTrees match {
        // extracting the condition of a for-expression: the function's
        // parameters have to be passed in as well
        case List(fn: Function) if fn.pos.isTransparent => fn.vparams.map(_.symbol) ::: localDeps
        case _ => localDeps
      }
    }

    // no argument list when nothing flows in, otherwise a single list with all parameters
    val paramLists = if (inbound.isEmpty) Nil else inbound :: Nil

    val outbound = outboundLocalDependencies(selection)

    val returnStats = if (outbound.isEmpty) Nil else mkReturn(outbound) :: Nil

    val newDef = mkDefDef(NoMods withPosition (Flags.PRIVATE, NoPosition), methodName, paramLists, selectedTrees ::: returnStats)

    val call = mkCallDefDef(methodName, inbound :: Nil, outbound)

    val isEnclosingTemplate = filter {
      case Template(_, _, body) => body exists (_ == selectedMethod)
    }

    val isSelectedMethod = filter {
      case d: DefDef => d == selectedMethod
    }

    // Several statements: replace their sequence inside the enclosing block by the call.
    val replaceBlockOfStatements = topdown {
      matchingChildren {
        transform {
          case block @ BlockExtractor(stats) if stats.nonEmpty =>
            mkBlock(stats.replaceSequence(selectedTrees, call :: Nil)) replaces block
        }
      }
    }

    // A single expression is replaced directly; with several statements this always fails.
    val replaceExpression =
      if (selectedTrees.size == 1) replaceTree(selectedTrees.head, call)
      else fail[Tree]

    // Insert the new method after the selected one; members without a range
    // position (synthetic ones) never stop the scan.
    val insertNewMethod = transform {
      case tpl @ Template(_, _, body) =>
        val methodPoint = selectedMethod.pos.point
        val (before, after) = body.span { member =>
          !member.pos.isRange || member.pos.point <= methodPoint
        }
        tpl copy (body = before ::: newDef :: after) replaces tpl
    }

    val replaceSelection = topdown {
      matchingChildren {
        isSelectedMethod &> replaceBlockOfStatements |> replaceExpression
      }
    }

    val extractMethod = topdown {
      matchingChildren {
        isEnclosingTemplate &> replaceSelection &> insertNewMethod
      }
    }

    Right(transformFile(selection.file, extractMethod))
  }
}
120 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/GenerateHashcodeAndEquals.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
/**
 * Refactoring that generates `canEqual`, `equals` and `hashCode` implementations
 * following the recommendations given in chapter 28 of Programming in Scala, and
 * adds `Equals` as a parent unless it is already present.
 */
abstract class GenerateHashcodeAndEquals extends ClassParameterDrivenSourceGeneration {

  import global._

  /** No class is rejected by this refactoring. */
  override def failsBecause(classDef: ClassDef) = None

  override def sourceGeneration(selectedParams: List[ValDef], preparationResult: PreparationResult, refactoringParams: RefactoringParameters) = {

    val classDef = preparationResult.classDef
    val generated = mkEqualityMethods(classDef.symbol, selectedParams, refactoringParams.callSuper)
    val parentsToAdd = newParentNames(classDef, selectedParams).map(parentName => Ident(newTermName(parentName)))

    val equalityNames = List(nme.equals_, nme.hashCode_, nme.canEqual_).map(_.toString)

    def isEqualityMethod(t: Tree) = t match {
      case d: ValOrDefDef => equalityNames contains d.nameString
      case _ => false
    }

    transform {
      case tpl @ Template(parents, self, body) =>
        // Either drop existing equality methods, or keep them and skip generating their names.
        val keptBody =
          if (refactoringParams.keepExistingEqualityMethods) body
          else body.filterNot(isEqualityMethod)
        val namesAlreadyDefined = keptBody.collect {
          case d: ValOrDefDef if equalityNames contains d.nameString => d.name
        }
        val methodsToAdd = generated.filterNot(m => namesAlreadyDefined contains m.name)
        Template(parents ::: parentsToAdd, self, keptBody ::: methodsToAdd) replaces tpl
    }
  }

  /** Builds canEqual, hashCode and equals (in that order) and returns them as canEqual, equals, hashCode. */
  private def mkEqualityMethods(classSymbol: Symbol, params: List[ValDef], callSuper: Boolean) = {
    val canEqual = mkCanEqual(classSymbol)
    val hashcode = mkHashcode(classSymbol, params, callSuper)
    val equals = mkEquals(classSymbol, params, callSuper)
    List(canEqual, equals, hashcode)
  }

  /** Parent names to add: `Equals`, unless a parent with that name already exists. */
  protected def newParentNames(classDef: ClassDef, selectedParams: List[ValDef]): List[String] =
    if (classDef.impl.parents.exists(_.nameString == "Equals")) Nil
    else List("Equals")

}
59 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/ImportsHelper.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
4 | import scala.language.reflectiveCalls
5 | import scala.tools.refactoring.common.TracingImpl
6 | import scala.tools.refactoring.implementations.oimports.OrganizeImportsWorker
7 |
/**
 * Mixin that recomputes the import statements of a package so that a tree's
 * dependencies are imported.
 * NOTE(review): the self type requires an interactive compiler, indexes,
 * transformations and enriched trees to be mixed in alongside this trait.
 */
8 | trait ImportsHelper extends TracingImpl {
9 |
10 | self: common.InteractiveScalaCompiler with analysis.Indexes with transformation.Transformations with transformation.TreeTransformations with common.EnrichedTrees =>
11 |
12 | import global._
13 |
/**
 * Builds a transformation that rewrites the imports of the package chosen by
 * `findBestPackageForImports`. It keeps existing imports that are still needed,
 * adds sorted imports for missing dependencies, and drops imports from the
 * target package itself.
 *
 * @param importsUser       tree whose dependencies decide which imports are
 *                          needed; defaults to the package itself
 * @param targetPackageName package name passed to `mkImportTrees` and used to
 *                          drop imports from that same package; defaults to
 *                          the package's own name
 */
14 | def addRequiredImports(importsUser: Option[Tree], targetPackageName: Option[String]) = traverseAndTransformAll {
15 | findBestPackageForImports &> transformation[(PackageDef, List[Import], List[Tree]), Tree] {
16 | case (pkg, existingImports, rest) => {
17 | val user = importsUser getOrElse pkg
18 | val targetPkgName = targetPackageName getOrElse pkg.nameString
// Worker instantiated for this `global`; its path-dependent helpers are
// brought into scope by the imports below.
19 | val oi = new OrganizeImportsWorker[global.type](global) {
20 | import participants._
21 | import regionContext._
22 | import transformations._
23 | import treeToolbox._
// Participant mapping the existing imports to: still-needed existing imports
// followed by the new imports for uncovered dependencies.
24 | object NeededImports extends Participant {
25 | def doApply(trees: List[Import]) = {
26 | val externalDependencies = neededImports(user) filterNot { imp =>
27 | // We don't want to add imports for types that are
28 | // children of `importsUser`.
29 | index.declaration(imp.symbol).exists { declaration =>
30 | val sameFile = declaration.pos.source.file.canonicalPath == user.pos.source.file.canonicalPath
31 | def userPosIncludesDeclPos = user.pos.includes(declaration.pos)
32 | sameFile && userPosIncludesDeclPos
33 | }
34 | }
35 |
// Skip dependencies an existing import already covers: same qualifier string
// and either a wildcard selector or a selector whose name or rename matches.
// NOTE(review): a dependency that is not a `Select` would throw a MatchError
// here; confirm that `neededImports` only yields `Select` trees.
36 | val newImportsToAdd = externalDependencies filterNot {
37 | case Select(qualifier, name) =>
38 | val depPkgStr = importAsString(qualifier)
39 | val depNameStr = "" + name
40 |
41 | trees exists {
42 | case Import(expr, selectors) =>
43 | val impPkgStr = importAsString(expr)
44 |
45 | selectors exists { selector =>
46 | val selNameStr = "" + selector.name
47 | val selRenameStr = "" + selector.rename
48 |
49 | impPkgStr == depPkgStr && {
50 | selector.name == nme.WILDCARD || {
51 | selNameStr == depNameStr || selRenameStr == depNameStr
52 | }
53 | }
54 | }
55 | }
56 | }
// The existing imports, wrapped in a single Region, are re-checked against
// `user` by the worker's `recomputeAndModifyUnused` (presumably pruning
// unused selectors; verify against OrganizeImportsWorker).
57 | val existingStillNeededImports = new recomputeAndModifyUnused(user)(
58 | List(Region(trees.map { imp => new RegionImport(proto = imp)() })))
59 |
// Kept imports first, then the new imports built for `targetPkgName`, ordered by SortImports.
60 | existingStillNeededImports.flatMap { _.imports } ::: SortImports(mkImportTrees(newImportsToAdd, targetPkgName))
61 | }
62 | }
63 | }
64 |
// Drop imports whose qualifier is the target package itself and that have a
// single, non-renaming selector.
65 | val imports = oi.NeededImports(existingImports).filterNot { imp =>
66 | val noLongerNeeded = targetPkgName == imp.expr.toString && imp.selectors.size == 1 && {
67 | // Note that we don't touch imports with multiple selectors here. This limitation,
68 | // that should not result in any regressions, might be addressed in the future.
69 | val s = imp.selectors.head
70 | s.name == s.rename
71 | }
72 |
73 | noLongerNeeded
74 | }
75 |
76 | // When we move the whole file, we only want to add imports to the originating package
77 | pkg copy (stats = imports ::: rest) replaces pkg
78 | }
79 | }
80 | }
81 |
82 | }
83 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/InlineLocal.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package implementations
7 |
8 | import common.Change
9 | import transformation.TreeFactory
10 | import analysis.Indexes
11 |
/**
 * Refactoring that inlines a private or local, immutable, non-parameter value:
 * every reference is replaced by the value's right-hand side and the definition
 * is removed.
 */
abstract class InlineLocal extends MultiStageRefactoring with ParameterlessRefactoring with TreeFactory with Indexes with common.InteractiveScalaCompiler {

  import global._

  override type PreparationResult = ValDef

  override def prepare(s: Selection) = {

    // A reference to the value, or the value's definition itself, may be selected.
    val selectedValue = s.findSelectedOfType[RefTree] match {
      case Some(ref) => index.declaration(ref.symbol) collect { case v: ValDef => v }
      case None => s.findSelectedOfType[ValDef]
    }

    def isInliningAllowed(sym: Symbol) =
      (sym.isPrivate || sym.isLocal) && !sym.isMutable && !sym.isValueParameter

    selectedValue match {
      case None =>
        Left(PreparationError("No local value selected."))
      case Some(v) if !isInliningAllowed(v.symbol) =>
        Left(PreparationError("The selected value cannot be inlined."))
      case Some(v) =>
        Right(v)
    }
  }

  override def perform(selection: Selection, selectedValue: PreparationResult): Either[RefactoringError, List[Change]] = {

    trace("Selected: %s", selectedValue)

    // Removes the definition from whichever template or block contains it.
    val removeSelectedValue = {
      def withoutSelected(stats: List[Tree]) = stats replaceSequence (List(selectedValue), Nil)

      transform {
        case tpl @ Template(_, _, body) if body contains selectedValue =>
          tpl.copy(body = withoutSelected(body)) replaces tpl
        case block @ BlockExtractor(stats) if stats contains selectedValue =>
          mkBlock(withoutSelected(stats)) replaces block
      }
    }

    val references = index references selectedValue.symbol

    val replaceReferenceWithRhs = {
      val replacement = selectedValue.rhs match {
        // inlining `list.filter _` should not include the `_`
        case Function(vparams, Apply(fun, _)) if vparams.forall(_.symbol.isSynthetic) => fun
        case rhs => rhs
      }

      trace("Value is referenced on lines: %s", references map (_.pos.lineContent) mkString "\n ")

      transform {
        case t if references contains t => replacement
      }
    }

    if (references.nonEmpty) {
      val inlineAll = topdown(matchingChildren(removeSelectedValue &> topdown(matchingChildren(replaceReferenceWithRhs))))
      Right(transformFile(selection.file, inlineAll))
    } else {
      Left(RefactoringError("No references to selected val found."))
    }
  }
}
84 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/IntroduceProductNTrait.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
4 |
/**
 * Refactoring that implements the ProductN trait for a class.
 * Given N selected class parameters, it generates the `_1` ... `_N` projections
 * on top of the hashCode/equals generation it inherits. It uses
 * `ProductN[...]` instead of `Equals` as the parent to add; the type arguments
 * are rendered from each parameter's type tree name.
 * @see GenerateHashcodeAndEquals
 */
abstract class IntroduceProductNTrait extends GenerateHashcodeAndEquals {

  import global._

  override def sourceGeneration(selectedParams: List[ValDef], preparationResult: PreparationResult, refactoringParams: RefactoringParameters) = {
    val equalityGeneration = super.sourceGeneration(selectedParams, preparationResult, refactoringParams)

    // `_k` returns the k-th selected parameter (1-based).
    val projections = selectedParams.zipWithIndex.map { case (param, index) =>
      mkDefDef(name = "_" + (index + 1), body = List(Ident(param.name)))
    }

    val prependProjections = transform {
      case tpl @ Template(_, _, body) => tpl.copy(body = projections ::: body) replaces tpl
    }

    equalityGeneration &> prependProjections
  }

  override def newParentNames(classDef: ClassDef, selectedParams: List[ValDef]) = {
    val typeArgs = selectedParams.map(_.tpt.nameString).mkString(", ")
    List(s"Product${selectedParams.length}[$typeArgs]")
  }

}
43 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/MergeParameterLists.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
/**
 * Refactoring that merges parameter lists of a method: every selected list is
 * appended to the list before it, in both the definition and its call sites.
 */
abstract class MergeParameterLists extends MethodSignatureRefactoring {

  import global._

  /** Indexes of the parameter lists to merge into their preceding list. */
  type MergePositions = List[Int]
  type RefactoringParameters = MergePositions

  override def checkRefactoringParams(prep: PreparationResult, affectedDefs: AffectedDefs, params: RefactoringParameters) = {
    val selectedDefDef = prep.defdef
    // List 0 has no predecessor to merge into.
    val allowedMergeIndexesRange = 1 until selectedDefDef.vparamss.size

    def isNonEmpty = params.nonEmpty
    def isSorted = params.sorted == params
    def hasUniqueIndexes = params.distinct == params
    // `head`/`last` are safe: this is only evaluated after `isNonEmpty` held.
    def indexesInRange = allowedMergeIndexesRange containsSlice (params.head to params.last)
    def isMergeable = {
      val allAffectedDefs = affectedDefs.originals ::: affectedDefs.partials
      val preparedParams = allAffectedDefs.map(prepareParamsForSingleRefactoring(params, selectedDefDef, _))
      !preparedParams.exists(_ contains 0)
    }

    // `&&` short-circuits, so each check runs only if all previous ones passed.
    isNonEmpty && isSorted && hasUniqueIndexes && indexesInRange && isMergeable
  }

  override def defdefRefactoring(params: RefactoringParameters) = transform {
    case orig @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
      val mergedVparamss = vparamss.zipWithIndex.foldLeft(List.empty[List[ValDef]]) {
        case (acc, (vparams, index)) if params contains index => (acc.head ::: vparams) :: acc.tail
        case (acc, (vparams, _)) => vparams :: acc
      }.reverse
      DefDef(mods, name, tparams, mergedVparamss, tpt, rhs) replaces orig
  }

  override def applyRefactoring(params: RefactoringParameters) = transform {
    case apply @ Apply(fun, args) =>
      val originalTree = findOriginalTree(apply)
      val listIndex = paramListPos(originalTree) - 1
      if (!params.contains(listIndex)) {
        apply
      } else {
        fun match {
          case Apply(innerFun, innerArgs) => Apply(innerFun, innerArgs ::: args)
          case Select(Apply(innerFun, innerArgs), _) => Apply(innerFun, innerArgs ::: args)
          case _ => apply
        }
      }
  }

  override def traverseApply(t: => Transformation[Tree, Tree]) = bottomup(t)

  /** Shifts the merge indexes to `toRefactor`'s lists and drops those that fall on untouchable lists. */
  override def prepareParamsForSingleRefactoring(originalParams: RefactoringParameters, selectedMethod: DefDef, toRefactor: DefInfo): RefactoringParameters = {
    val untouchables = toRefactor.nrUntouchableParamLists
    val toShift = selectedMethod.vparamss.size - toRefactor.nrParamLists - untouchables
    originalParams.collect { case index if index - toShift >= untouchables => index - toShift }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/SplitParameterLists.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations
3 |
/**
 * Refactoring to split parameter lists of a method.
 */
abstract class SplitParameterLists extends MethodSignatureRefactoring {

  import global._

  /**
   * Positions inside one parameter list at which a new list starts: position `i`
   * puts parameters before index `i` and from index `i` on into different lists.
   */
  type SplitPositions = List[Int]
  /**
   * Split positions must be provided for every parameter list of the method (though they can be Nil)
   */
  type RefactoringParameters = List[SplitPositions]

  /**
   * Valid parameters provide exactly one (possibly empty) list of split positions per
   * parameter list. Each list is strictly increasing and lies within `1 until` the size
   * of its parameter list.
   */
  override def checkRefactoringParams(prep: PreparationResult, affectedDefs: AffectedDefs, params: RefactoringParameters) = {
    def checkRefactoringParamsHelper(vparamss: List[List[ValDef]], sectionss: List[SplitPositions]): Boolean = {
      val allStrictlyIncreasing = sectionss.forall(s => s.distinct.sorted == s)
      if (!allStrictlyIncreasing || vparamss.size != sectionss.size) {
        false
      } else {
        val emptyRange = 1 to 0
        val sectionRanges = sectionss map { case Nil => emptyRange; case s => s.head to s.last }
        // Splitting is only possible strictly inside a list: before index 1 .. size - 1.
        val vparamsRanges = vparamss.map(1 until _.size)
        (vparamsRanges zip sectionRanges).forall { case (allowed, requested) => allowed containsSlice requested }
      }
    }

    checkRefactoringParamsHelper(prep.defdef.vparamss, params)
  }

  /**
   * Splits `origVparams` at the given (strictly increasing) positions, e.g. positions
   * `List(1)` turn `List(a, b, c)` into `List(List(a), List(b, c))`. Without positions the
   * result is `List(origVparams)`.
   */
  def splitSingleParamList[T](origVparams: List[T], positions: SplitPositions): List[List[T]] = {
    // Length of each resulting list: the distance between consecutive boundaries.
    val nrParamsPerList = ((positions ::: List(origVparams.length)) zip (0 :: positions)) map {
      case (end, begin) => end - begin
    }
    nrParamsPerList.foldLeft((List.empty[List[T]], origVparams)) {
      case ((done, remaining), nrParams) =>
        val (currentCurriedParamList, remainingOrigParams) = remaining splitAt nrParams
        (done ::: List(currentCurriedParamList), remainingOrigParams)
    }._1
  }

  /**
   * Builds nested applications `baseFun(vparamss(0))(vparamss(1))...`.
   *
   * @throws IllegalArgumentException if `vparamss` is empty
   */
  def makeSplitApply(baseFun: Tree, vparamss: List[List[Tree]]) = {
    val firstApply = Apply(baseFun, vparamss.headOption.getOrElse(throw new IllegalArgumentException("can't handle empty vparamss")))
    vparamss.tail.foldLeft(firstApply)((fun, vparams) => Apply(fun, vparams))
  }

  /** Rewrites the method definition, splitting each parameter list at its positions. */
  override def defdefRefactoring(params: RefactoringParameters) = transform {
    case orig @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
      val split = (vparamss zip params) flatMap { case (vparams, positions) => splitSingleParamList(vparams, positions) }
      DefDef(mods, name, tparams, split, tpt, rhs) replaces orig
  }

  /** Rewrites call sites, splitting each argument list like the matching parameter list. */
  override def applyRefactoring(params: RefactoringParameters) = transform {
    case apply @ Apply(fun, args) =>
      val originalTree = findOriginalTree(apply)
      val pos = paramListPos(originalTree) - 1
      // NOTE(review): assumes `pos` is a valid index into `params`; otherwise `params(pos)` throws.
      val splitParamLists = splitSingleParamList(args, params(pos))
      val splitApply = makeSplitApply(fun, splitParamLists)
      splitApply replaces apply
  }

  override def traverseApply(t: ⇒ Transformation[Tree, Tree]) = bottomup(t)

  /**
   * Adapts the split positions of the selected method to `toRefactor`. The leading entries
   * that `toRefactor` does not have are dropped, and one empty entry (no split) is
   * prepended for each of its untouchable parameter lists.
   */
  override def prepareParamsForSingleRefactoring(originalParams: RefactoringParameters, selectedMethod: DefDef, toRefactor: DefInfo): RefactoringParameters = {
    val toDrop = originalParams.size - toRefactor.nrParamLists
    val preparedParams = originalParams.drop(toDrop)
    val noSplitters = List.fill(toRefactor.nrUntouchableParamLists)(Nil: SplitPositions)
    noSplitters ::: preparedParams
  }

}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/extraction/ExtractCode.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.implementations.extraction
2 |
/**
 * Extraction refactoring that does not commit to one kind of extraction up front;
 * its collector ([[AutoExtraction]]) offers every concrete extraction applicable
 * to the current selection.
 */
abstract class ExtractCode extends ExtractionRefactoring with AutoExtractions {
  val collector = AutoExtraction
}
10 |
trait AutoExtractions extends MethodExtractions with ValueExtractions with ExtractorExtractions with ParameterExtractions {
  /**
   * Proposes different kinds of extractions.
   */
  object AutoExtraction extends ExtractionCollector[Extraction] {
    /**
     * Extraction collectors used for auto extraction.
     */
    val availableCollectors =
      ExtractorExtraction ::
        ValueExtraction ::
        MethodExtraction ::
        ParameterExtraction ::
        Nil

    /**
     * Searches for an extraction source that is valid for at least one
     * extraction collector. If an appropriate source is found it calls
     * `collect(s)` on every collector that accepts this source.
     *
     * @return `Left` with a message when no extraction applies, otherwise `Right` with all
     *         extractions sorted by descending start position of their enclosing tree.
     */
    override def collect(s: Selection) = {
      // Reassigned by the expansion predicate; presumably it ends up holding the collectors
      // that accept the selection the expansion stopped at.
      var applicableCollectors: List[ExtractionCollector[Extraction]] = Nil
      val sourceOpt = s.expand.expandTo { source: Selection =>
        applicableCollectors = availableCollectors.filter(_.isValidExtractionSource(source))
        applicableCollectors.nonEmpty
      }

      // Only collect when a source was actually found (avoids an unguarded Option.get).
      val extractions = sourceOpt.toList.flatMap { source =>
        applicableCollectors.flatMap(_.collect(source).right.getOrElse(Nil))
      }

      if (extractions.isEmpty)
        Left("No applicable extraction found.")
      else
        Right(extractions.sortBy(-_.extractionTarget.enclosing.pos.startOrPoint))
    }

    /** Not implemented; the overridden `collect` above does not call it. */
    def isValidExtractionSource(s: Selection) = ???

    /** Not implemented; the overridden `collect` above does not call it. */
    def createExtractions(source: Selection, targets: List[ExtractionTarget], name: String) = ???
  }
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/extraction/ExtractParameter.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.implementations.extraction
2 |
3 | import scala.tools.refactoring.analysis.ImportAnalysis
4 |
/**
 * Refactoring that turns the selected expression into a new method parameter,
 * using [[ParameterExtraction]] as its collector.
 */
abstract class ExtractParameter extends ExtractionRefactoring with ParameterExtractions {
  val collector = ParameterExtraction
}
8 |
/**
 * Extracts an expression into a new parameter whose default value is
 * the extracted expression.
 */
trait ParameterExtractions extends Extractions with ImportAnalysis {
  import global._

  object ParameterExtraction extends ExtractionCollector[ParameterExtraction] {
    /** Accepts selections that represent a value but are not themselves a parameter. */
    def isValidExtractionSource(s: Selection) =
      s.representsValue && !s.representsParameter

    /** The new parameter is inserted at the end of the target's value parameter list. */
    override def createInsertionPosition(s: Selection) =
      atEndOfValueParameterList

    /**
     * One extraction per target, taken from the front of `targets` for as long as the
     * target's scope sees every local the selection depends on.
     */
    def createExtractions(source: Selection, targets: List[ExtractionTarget], name: String) = {
      val validTargets = targets.takeWhile { t =>
        source.inboundLocalDeps.forall(t.scope.sees(_))
      }

      validTargets.map(ParameterExtraction(source, _, name))
    }
  }

  case class ParameterExtraction(
    extractionSource: Selection,
    extractionTarget: ExtractionTarget,
    abstractionName: String) extends Extraction {

    val displayName = extractionTarget.enclosing match {
      case t: DefDef => s"Extract Parameter to Method ${t.symbol.nameString}"
      // Any other enclosing tree gets a generic label instead of failing with a MatchError.
      case _ => "Extract Parameter"
    }

    val functionOrDefDef = extractionTarget.enclosing

    /**
     * Replaces the selection by a reference to the new parameter and inserts the
     * parameter (typed like, and defaulting to, the extracted expression) at the target.
     */
    def perform() = {
      val tpe = defaultVal.tpe
      val param = mkParam(abstractionName, tpe, defaultVal)
      val paramRef = Ident(newTermName(abstractionName))

      extractionSource.replaceBy(paramRef) ::
        extractionTarget.insert(param) ::
        Nil
    }

    /** The last selected top-level tree, used as the parameter's default value. */
    val defaultVal = {
      extractionSource.selectedTopLevelTrees.last
    }

    def withAbstractionName(name: String) =
      copy(abstractionName = name).asInstanceOf[this.type]
  }
}
61 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/extraction/ExtractValue.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.implementations.extraction
2 |
3 | import scala.tools.refactoring.analysis.ImportAnalysis
4 |
/**
 * Refactoring that extracts one or more selected expressions (or value definitions)
 * into a new `val` definition, using [[ValueExtraction]] as its collector.
 */
abstract class ExtractValue extends ExtractionRefactoring with ValueExtractions {
  val collector = ValueExtraction
}
11 |
trait ValueExtractions extends Extractions with ImportAnalysis {
  import global._

  /** Collector proposing a [[ValueExtraction]] for each suitable extraction target. */
  object ValueExtraction extends ExtractionCollector[ValueExtraction] {

    /** A selection qualifies if it is a value or value definitions, but not a parameter. */
    def isValidExtractionSource(s: Selection) =
      (s.representsValue || s.representsValueDefinitions) && !s.representsParameter

    /**
     * One extraction per target, taken from the front of `targets` for as long as the
     * target's scope sees every local symbol the selection depends on.
     */
    def createExtractions(source: Selection, targets: List[ExtractionTarget], name: String) = {
      def seesAllInboundDeps(target: ExtractionTarget) =
        source.inboundLocalDeps.forall(dep => target.scope.sees(dep))

      targets.takeWhile(seesAllInboundDeps).map(target => ValueExtraction(source, target, name))
    }
  }

  /** Extraction of `extractionSource` into a `val` named `abstractionName` at `extractionTarget`. */
  case class ValueExtraction(
    extractionSource: Selection,
    extractionTarget: ExtractionTarget,
    abstractionName: String) extends Extraction {

    val displayName = extractionTarget.enclosing match {
      case t: Template => s"Extract Value to ${t.symbol.owner.decodedName}"
      case _ => "Extract Local Value"
    }

    def withAbstractionName(name: String) =
      copy(abstractionName = name).asInstanceOf[this.type]

    lazy val imports = buildImportTree(extractionSource.root)

    /**
     * Replaces the selection by a reference to the new value and inserts the value
     * definition at the target. Its body consists of required imports, the extracted
     * trees, and (if there are outbound dependencies) a return of those dependencies.
     */
    def perform() = {
      val outboundDeps = extractionSource.outboundLocalDeps
      val call = mkCallValDef(abstractionName, outboundDeps)

      val returnStatements =
        if (outboundDeps.isEmpty) Nil
        else List(mkReturn(outboundDeps))

      // Presumably the imports the extracted trees need at the target position — see findRequiredImports.
      val importStatements = extractionSource.selectedTopLevelTrees.flatMap { selected =>
        imports.findRequiredImports(selected, extractionSource.pos, extractionTarget.pos)
      }

      val extractedStatements = extractionSource.selectedTopLevelTrees match {
        case List(fn @ Function(_, Block(stmts, expr))) =>
          // Rebuild the function with a fresh body block while keeping its original type.
          val rebuilt = fn copy (body = Block(stmts, expr))
          rebuilt.tpe = fn.tpe
          List(rebuilt)
        case trees => trees
      }

      val statements = importStatements ::: extractedStatements ::: returnStatements

      val abstraction = statements match {
        // Add explicit type annotations for extracted functions
        case (fn: Function) :: Nil => mkValDef(abstractionName, fn, TypeTree(fn.tpe))
        case single :: Nil => mkValDef(abstractionName, single)
        case several => mkValDef(abstractionName, mkBlock(several))
      }

      List(
        extractionSource.replaceBy(call, preserveHierarchy = true),
        extractionTarget.insert(abstraction))
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/oimports/ImportsOrganizer.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package implementations.oimports
3 |
4 | import scala.tools.nsc.Global
5 |
6 | import sourcegen.Formatting
7 |
abstract class ImportsOrganizer[G <: Global, U <: TreeToolbox[G]](val treeToolbox: U) {
  import treeToolbox.global._
  type T <: Tree

  /** True when every import of the group starts on a different source line. */
  protected def noAnyTwoImportsInSameLine(importsGroup: List[Import]): Boolean =
    importsGroup.size == importsGroup.map { _.pos.line }.distinct.size

  /**
   * Groups consecutive imports among `trees`; any non-import tree closes the current group.
   * Groups, and the imports inside each group, keep their source order; empty groups are dropped.
   */
  private def importsGroupsFromTree(trees: List[Tree]): List[List[Import]] =
    // Walking right-to-left lets every step prepend, so grouping is linear in the number of trees.
    trees.reverse.foldLeft(List(List.empty[Import])) {
      case (currentGroup :: finishedGroups, imp: Import) => (imp :: currentGroup) :: finishedGroups
      case (groups, _) => Nil :: groups
    }.filter(_.nonEmpty)

  /** The trees of kind `T` found in `tree`, each paired with the symbol used as owner of its imports. */
  protected def forTreesOf(tree: Tree): List[(T, Symbol)]

  /** The child trees of `parent` that are scanned for import groups. */
  protected def treeChildren(parent: T): List[Tree]

  private def toRegions(groupedImports: List[List[Import]], importsOwner: Symbol, formatting: Formatting): List[treeToolbox.Region] =
    groupedImports.collect {
      case imports @ (_ :: _) => RegionBuilder[G, U](treeToolbox)(imports, importsOwner, formatting, "")
    }.flatten

  /**
   * Builds import regions for every tree of kind `T` inside `tree`. Groups rejected by
   * `noAnyTwoImportsInSameLine` produce no region.
   */
  def transformTreeToRegions(tree: Tree, formatting: Formatting): List[treeToolbox.Region] = forTreesOf(tree).flatMap {
    case (extractedTree, treeOwner) =>
      val groupedImports = importsGroupsFromTree(treeChildren(extractedTree)).filter(noAnyTwoImportsInSameLine)
      toRegions(groupedImports, treeOwner, formatting)
  }
}
44 |
/** Organizes imports located in blocks owned by a (non-lazy) method. */
class DefImportsOrganizer[G <: Global, U <: TreeToolbox[G]](override val treeToolbox: U) extends ImportsOrganizer[G, U](treeToolbox) {
  import treeToolbox.global._
  import treeToolbox.forTreesOfKind
  type T = Block

  override protected def treeChildren(block: Block) = block.stats

  /**
   * Collects each block whose current owner is a non-lazy method, then keeps traversing
   * its statements and its result expression.
   */
  override protected def forTreesOf(tree: Tree) = forTreesOfKind[Block](tree) { collector =>
    {
      case block @ Block(statements, result) if collector.currentOwner.isMethod && !collector.currentOwner.isLazy =>
        collector.collect(block)
        statements foreach collector.traverse
        collector.traverse(result)
    }
  }
}
61 |
/** Organizes imports located directly in class/trait/object templates. */
class ClassDefImportsOrganizer[G <: Global, U <: TreeToolbox[G]](override val treeToolbox: U) extends ImportsOrganizer[G, U](treeToolbox) {
  import treeToolbox.global._
  import treeToolbox.forTreesOfKind
  type T = Template

  override protected def treeChildren(template: Template) = template.body

  /** Collects every template, then keeps traversing its body. */
  override protected def forTreesOf(tree: Tree) = forTreesOfKind[Template](tree) { collector =>
    {
      case template @ Template(_, _, members) =>
        collector.collect(template)
        members foreach collector.traverse
    }
  }
}
77 |
/**
 * Organizes imports located directly in package definitions. Unlike the other organizers,
 * it accepts groups where several imports share a line.
 */
class PackageDefImportsOrganizer[G <: Global, U <: TreeToolbox[G]](override val treeToolbox: U) extends ImportsOrganizer[G, U](treeToolbox) {
  import treeToolbox.global._
  import treeToolbox.forTreesOfKind
  type T = PackageDef

  override protected def treeChildren(packageDef: PackageDef) = packageDef.stats

  override protected def noAnyTwoImportsInSameLine(importsGroup: List[Import]): Boolean = true

  /**
   * Collects every package definition (owned by the symbol its term symbol references),
   * then keeps traversing its statements.
   */
  override protected def forTreesOf(tree: Tree) = forTreesOfKind[PackageDef](tree) { collector =>
    {
      case packageDef @ PackageDef(_, statements) =>
        collector.collect(packageDef, packageDef.symbol.asTerm.referenced)
        statements foreach collector.traverse
    }
  }
}
94 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/package.scala:
--------------------------------------------------------------------------------
1 | package scala.tools
2 |
3 | import scala.tools.nsc.interactive.PresentationCompilerThread
4 |
package object refactoring {

  /**
   * Asserts that the caller runs on the presentation compiler (PC) thread.
   * Many operations on compiler symbols can trigger further compilation,
   * and that compilation has to happen on the PC thread.
   *
   * To run an operation on the PC thread, use global.ask { .. }
   */
  def assertCurrentThreadIsPresentationCompiler(): Unit = {
    val onPresentationCompilerThread = Thread.currentThread().isInstanceOf[PresentationCompilerThread]
    assert(onPresentationCompilerThread, "operation should be running on the presentation compiler thread")
  }

  /**
   * Simple class name of `o`, falling back to the fully qualified name.
   *
   * `getClass.getSimpleName` can throw InternalError("Malformed class name") (probably
   * related to #SI-2034) or NoClassDefFoundError; both are caught here.
   */
  def getSimpleClassName(o: Object): String = {
    val clazz = o.getClass
    try clazz.getSimpleName
    catch {
      case _: InternalError | _: NoClassDefFoundError => clazz.getName
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package sourcegen
7 |
8 | import scala.reflect.internal.util.SourceFile
9 |
trait AbstractPrinter extends CommonPrintUtils {

  this: common.Tracing with common.EnrichedTrees with Indentations with common.CompilerAccess with Formatting =>

  import global._

  /**
   * Environment threaded through all print methods: the current indentation, the set of
   * changed trees, the parent of the tree being printed, and the source file (if any).
   */
  case class PrintingContext(ind: Indentation, changeSet: ChangeSet, parent: Tree, file: Option[SourceFile]) {
    /** "\r\n" when the source file already contains a CRLF sequence, "\n" otherwise. */
    lazy val newline: String =
      if (file.exists(source => source.content.containsSlice("\r\n"))) "\r\n" else "\n"
  }

  /** Tells the printer which trees were modified. */
  trait ChangeSet {
    def hasChanged(t: Tree): Boolean
  }

  /** Prints `t` within the given context. */
  def print(t: Tree, ctx: PrintingContext): Fragment

}
36 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/sourcegen/CommentsUtils.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.sourcegen
2 |
/**
 * Kept only for backward compatibility with existing callers;
 * new code should use [[SourceUtils]] directly.
 */
object CommentsUtils extends SourceUtils
7 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/sourcegen/CommonPrintUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package sourcegen
7 |
8 | import scala.reflect.internal.util.BatchSourceFile
9 |
trait CommonPrintUtils {

  this: common.CompilerAccess with AbstractPrinter =>

  import global._

  /** A newline requisite without indentation, using the context's line separator. */
  def newline(implicit ctx: PrintingContext) = Requisite.newline("", ctx.newline)

  /** A newline requisite followed by the context's current indentation. */
  def indentedNewline(implicit ctx: PrintingContext) = Requisite.newline(ctx.ind.current, ctx.newline)

  /** A newline requisite indented one default step deeper than the current indentation. */
  def newlineIndentedToChildren(implicit ctx: PrintingContext) = Requisite.newline(ctx.ind.incrementDefault.current, ctx.newline)

  /** The context's current indentation string. */
  def indentation(implicit ctx: PrintingContext) = ctx.ind.current

  /**
   * Renders type `t` as source text for the type tree `tree`.
   *
   * Constant and refined types are printed via their type symbol. Function and method types
   * are printed as `A => B`. If `tree` has an `original` tree, tuple TypeRefs are printed
   * via `toString` and non-TypeRef types are printed from the original tree.
   */
  def typeToString(tree: TypeTree, t: Type)(implicit ctx: PrintingContext): String = {

    // A lone function parameter that is itself a function or tuple type must be parenthesized:
    // otherwise `(A => B) => C` prints as `A => B => C` (a different, right-associated type)
    // and `((A, B)) => C` prints as `(A, B) => C` (a two-argument function).
    def parameterTypeToString(paramType: Type): String = {
      val printed = typeToString(tree, paramType)
      if (definitions.isFunctionType(paramType) || definitions.isTupleType(paramType)) "(" + printed + ")"
      else printed
    }

    t match {
      case tpe if tpe == EmptyTree.tpe => ""
      case tpe: ConstantType =>
        tpe.typeSymbol.tpe.toString
      case tpe: TypeRef if tree.original != null && tpe.sym.nameString.matches("Tuple\\d+") =>
        tpe.toString
      case tpe if tree.original != null && !tpe.isInstanceOf[TypeRef] =>
        print(tree.original, ctx).asText
      case tpe: RefinedType =>
        tpe.typeSymbol.tpe.toString
      case typeRef @ TypeRef(_, _, arg1 :: ret :: Nil) if definitions.isFunctionType(typeRef) =>
        parameterTypeToString(arg1) + " => " + typeToString(tree, ret)
      case MethodType(params, result) =>
        val printedParams = params match {
          case single :: Nil => parameterTypeToString(single.tpe)
          case _ => params.map(s => typeToString(tree, s.tpe)).mkString(", ")
        }
        val printedResult = typeToString(tree, result)

        if (params.isEmpty) {
          "() => " + printedResult
        } else if (params.size > 1) {
          "(" + printedParams + ") => " + printedResult
        } else {
          printedParams + " => " + printedResult
        }

      case tpe =>
        tpe.toString
    }
  }

  /** Like [[balanceBrackets]], but for a layout. */
  def balanceBracketsInLayout(open: Char, close: Char, l: Layout) = {
    balanceBrackets(open, close)(Fragment(l.asText)).toLayout
  }

  /**
   * Returns the text of `f` with missing brackets added: missing `close` characters are
   * appended, and missing `open` characters are prepended, using the counts from
   * `SourceUtils.countRelevantBrackets`.
   */
  def balanceBrackets(open: Char, close: Char)(f: Fragment) = Fragment {
    val (opening, closing) = SourceUtils.countRelevantBrackets(f.toLayout.asText, open, close)
    if (opening > closing) {
      f.asText + (("" + close) * (opening - closing))
    } else if (opening < closing) {
      (("" + open) * (closing - opening)) + f.asText
    } else {
      f.asText
    }
  }

  /**
   * When extracting source code from the file via a tree's position,
   * it depends on the tree type whether we can use the position's
   * start or point.
   *
   * @param t The tree that will be replaced.
   * @param p The position to adapt. This does not have to be the position of t.
   */
  def adjustedStartPosForSourceExtraction(t: Tree, p: Position): Position = t match {
    case _: Select | _: New if t.pos.isRange && t.pos.start > t.pos.point =>
      p withStart (p.start min p.point)
    case _ =>
      p
  }

  /** Operator precedence of a name, as computed by the compiler's parser. */
  lazy val precedence: Name => Int = {

    // Copied from the compiler
    def newUnitParser(code: String) = new syntaxAnalyzer.UnitParser(newCompilationUnit(code))
    def newCompilationUnit(code: String) = new CompilationUnit(newSourceFile(code))
    def newSourceFile(code: String) = new BatchSourceFile("", code)

    val parser = newUnitParser("")

    // I ♥ Scala
    name => parser.precedence(newTermName(name.decode))
  }

}
98 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/sourcegen/Formatting.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring.sourcegen
6 |
7 | import scala.util.Properties
8 |
/**
 * Default formatting preferences for generated source code.
 * Each preference is a method, so mixers can override them individually.
 */
trait Formatting {

  /** Indentation unit used for changed and newly generated code. */
  def defaultIndentationStep: String = " "

  /**
   * Padding placed inside the braces of an import with several selectors,
   * i.e. around `name` in `import a.{name}`.
   */
  def spacingAroundMultipleImports: String = ""

  /**
   * Whether printed imports drop a leading `scala.` prefix, e.g.
   * `import scala.util.Try` is printed as `import util.Try` when `true`.
   */
  def dropScalaPackage: Boolean = false

  /** Line separator to use when a new line has to be added to a source file. */
  def lineDelimiter: String = Properties.lineSeparator
}
39 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/sourcegen/Fragment.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package sourcegen
7 |
/**
 * A piece of generated source text consisting of three layouts (`leading`, `center`,
 * `trailing`) plus the requisites `pre` and `post`, which are applied around the
 * fragment when it is turned into text or combined with other fragments/layouts.
 */
trait Fragment {
  self =>
  val leading: Layout
  val center: Layout
  val trailing: Layout

  // By default a fragment has no requisites.
  val pre = NoRequisite: Requisite
  val post = NoRequisite: Requisite

  override def toString() = asText

  /** A copy of this fragment with `leading` replaced by NoLayout. */
  def dropLeadingLayout: Fragment = new Fragment {
    val leading = NoLayout
    val center = self.center
    val trailing = self.trailing
    override val pre = self.pre
    override val post = self.post
  }

  /** A copy of this fragment whose `leading` text has its leading whitespace removed. */
  def dropLeadingIndentation: Fragment = new Fragment {
    val leading = Layout(self.leading.asText.replaceFirst("""^\s*""", ""))
    val center = self.center
    val trailing = self.trailing
    override val pre = self.pre
    override val post = self.post
  }

  /** A copy of this fragment without `trailing` layout; the `post` requisite is dropped too. */
  def dropTrailingLayout: Fragment = new Fragment {
    val leading = self.leading
    val center = self.center
    val trailing = NoLayout
    override val pre = self.pre
    override val post = NoRequisite
  }

  /** True for EmptyFragment and for any fragment whose text (`asText`) is empty. */
  def isEmpty: Boolean = this match {
    case EmptyFragment => true
    case _ if asText == "" => true
    case _ => false
  }

  /** Applies `f` to this fragment unless it is empty, in which case EmptyFragment is returned. */
  def ifNotEmpty(f: Fragment => Fragment): Fragment = this match {
    case EmptyFragment => EmptyFragment
    case _ if asText == "" => EmptyFragment
    case _ => f(this)
  }

  /** The full text of this fragment (see `asText`) as a Layout. */
  def toLayout: Layout = new Layout {
    def asText = self.asText
  }

  /**
   * The text of this fragment: `pre` applied in front of `leading`, then `center`,
   * then `post` applied after `trailing`.
   */
  def asText: String = pre(NoLayout, leading).asText + center.asText + post(trailing, NoLayout).asText

  /**
   * Combines two fragments, makes sure that
   * Requisites are satisfied.
   *
   * Combining two fragments (a,b,c) and (d,e,f)
   * yields a fragment (a,bcde,f).
   *
   * The combined requisite `self.post ++ o.pre` is applied between c and d.
   */
  def ++ (o: Fragment): Fragment = o match {
    case EmptyFragment => this
    case _ => new Fragment {
      val leading = self.leading
      val center = self.center ++ (self.post ++ o.pre)(self.trailing, o.leading) ++ o.center
      val trailing = o.trailing

      override val pre = self.pre
      override val post = o.post
    }
  }

  /**
   * Combines a fragment with a layout, makes sure that
   * Requisites are satisfied.
   *
   * Combining (a,b,c) and (d)
   * yields a fragment (a,b,cd).
   *
   * This fragment's `post` requisite is applied between c and d and is not kept afterwards.
   */
  def ++ (o: Layout): Fragment = o match {
    case NoLayout => this
    case _ => new Fragment {
      val leading = self.leading
      val center = self.center
      val trailing = self.post(self.trailing, o)

      override val pre = self.pre
      override val post = /*if (self.post.isRequired(this.trailing, NoLayout)) self.post else*/ NoRequisite
    }
  }

  /**
   * Adds requisites to this fragment: `after` is appended to `post`,
   * and `before` is prepended to `pre`.
   */
  def ++ (after: Requisite, before: Requisite = NoRequisite): Fragment = {
    new Fragment {
      val leading = self.leading
      val center = self.center
      val trailing = self.trailing

      override val pre = before ++ self.pre
      override val post = self.post ++ after
    }
  }
}
110 |
/**
 * A fragment whose leading, center and trailing layouts are all [[NoLayout]];
 * subclasses may override individual parts.
 */
abstract class EmptyFragment extends Fragment {
  val leading: Layout = NoLayout
  val center: Layout = NoLayout
  val trailing: Layout = NoLayout
}

/** The canonical empty fragment. */
object EmptyFragment extends EmptyFragment
118 |
/** Factories and extractor for [[Fragment]]s. */
object Fragment {

  /** Deconstructs a fragment into its leading, center and trailing layouts. */
  def unapply(f: Fragment) = Some(Tuple3(f.leading, f.center, f.trailing))

  /** A fragment built from the three given layouts, without requisites. */
  def apply(l: Layout, c: Layout, t: Layout) = new Fragment {
    val leading = l
    val center = c
    val trailing = t
  }

  /** A fragment whose center is the given text and whose other layouts are empty. */
  def apply(s: String) = new EmptyFragment {
    override val center = Layout(s)
  }
}
133 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/sourcegen/Indentation.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package sourcegen
7 |
/**
 * Indentation support shared by the pretty printer and the source generator.
 *
 * An [[Indentation]] holds the `current` indentation string and the `defaultIncrement`
 * by which newly generated (pretty-printed) code is indented further.
 */
trait Indentations {

  this: common.Tracing =>

  class Indentation(val defaultIncrement: String, val current: String) {

    /** One level deeper: `current` followed by `defaultIncrement`. */
    def incrementDefault = new Indentation(defaultIncrement, current + defaultIncrement)

    /** Same increment, with the current indentation replaced by `i`. */
    def setTo(i: String) = new Indentation(defaultIncrement, i)

    /** True if `oldIndentation` differs from `current` and any surrounding layout contains a newline. */
    def needsToBeFixed(oldIndentation: String, surroundingLayout: Layout*) = {
      oldIndentation != current && surroundingLayout.exists(_.contains("\n"))
    }

    /** Replaces `oldIndentation` after every "\n" in `code` by the current indentation. */
    def fixIndentation(code: String, oldIndentation: String) = {
      trace("code is %s", code)
      trace("desired indentation is %s", current)
      trace("current indentation is %s", oldIndentation)
      Layout(code.replace("\n"+ oldIndentation, "\n"+ current))
    }
  }

  /**
   * The leading whitespace of the source line containing the start of `tree`,
   * computed on the source with comments stripped. If the tree starts at a line
   * break or at the end of the file, the preceding line is used.
   */
  def indentationString(tree: scala.tools.nsc.Global#Tree): String = {
    val source = tree.pos.source
    val start = tree.pos.start

    // Comment-stripped content, cached per path. A cached entry is reused only while the file's
    // content array is the very same instance, so edited content is never served stale.
    def stripCommentFromSourceFile(): String = {
      val content = source.content
      memoizedSourceWithoutComments.get(source.path) match {
        case Some((cachedContent, stripped)) if cachedContent eq content =>
          stripped
        case _ =>
          val stripped = SourceUtils.stripComment(content)
          memoizedSourceWithoutComments(source.path) = (content, stripped)
          stripped
      }
    }

    val from =
      if (start == source.length || source.content(start) == '\n' || source.content(start) == '\r') start - 1
      else start
    val contentWithoutComment = stripCommentFromSourceFile()

    // -1 (no earlier line break, i.e. the first line) turns into line start 0.
    val lineStart = contentWithoutComment.lastIndexOf('\n', from) + 1

    """\s*""".r.findFirstIn(contentWithoutComment.slice(lineStart, start)).getOrElse("")
  }

  private [this] val memoizedSourceWithoutComments =
    scala.collection.mutable.Map.empty[String, (Array[Char], String)]

}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/sourcegen/Layout.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package sourcegen
7 |
8 | import scala.reflect.internal.util.SourceFile
9 |
/**
 * A piece of source text, typically the whitespace and comments between trees.
 */
trait Layout {
  self =>

  /** Tests whether the comment-free text contains `s`. */
  def contains(s: String) = withoutComments.contains(s)

  /** Matches the regular expression `r` against the comment-free text. */
  def matches(r: String) = withoutComments.matches(r)

  /**
   * @return Returns this layout as a string but without comments.
   *         Comments are replaced by whitespace.
   */
  lazy val withoutComments = SourceUtils.stripComment(asText)

  def asText: String

  override def toString() = asText

  /** Concatenates two layouts; appending NoLayout yields this layout itself. */
  def ++ (o: Layout) =
    if (NoLayout == o) this
    else new Layout {
      override def asText = self.asText + o.asText
    }

  /**
   * Prepends this layout to fragment `o`: `o.pre` is applied between this layout and
   * `o.leading`, and is kept as requisite only if it is still required afterwards.
   */
  def ++ (o: Fragment): Fragment = new Fragment {
    val leading = o.pre(self, o.leading)
    val center = o.center
    val trailing = o.trailing

    // `this.leading` is the merged layout defined above (initialized before this val).
    override val pre = if (o.pre.isRequired(this.leading, NoLayout)) o.pre else NoRequisite
    override val post = o.post
  }

  /** A fragment consisting only of this layout as trailing part, with `r` as its post requisite. */
  def ++ (r: Requisite): Fragment = new EmptyFragment {
    override val trailing = self
    override val post = r
  }

  def isEmpty: Boolean = asText.isEmpty
  def nonEmpty: Boolean = !isEmpty
}
51 |
/** The empty layout; its text is the empty string. */
case object NoLayout extends Layout {
  val asText = ""
}
55 |
/** Factories for [[Layout]]s. */
object Layout {

  /**
   * Layout backed by the characters of `source` from `start` (inclusive) to `end` (exclusive).
   * The split methods search the comment-free text and return two layouts of the same file.
   */
  case class LayoutFromFile(source: SourceFile, start: Int, end: Int) extends Layout {

    lazy val asText = source.content.slice(start, end).mkString

    /** Splits right after the found character; `(this, NoLayout)` if none is found. */
    def splitAfter(cs: Char*): (Layout, Layout) = splitFromLeft(cs) match {
      case Some(i) => (copy(end = i + 1), copy(start = i + 1))
      case None => (this, NoLayout)
    }

    /** Like `splitAfter`, but locates each character's last occurrence. */
    def splitAfterLast(cs: Char*): (Layout, Layout) = splitFromRight(cs) match {
      case Some(i) => (copy(end = i + 1), copy(start = i + 1))
      case None => (this, NoLayout)
    }

    /** Splits around the found character, which belongs to neither part; `(this, NoLayout)` if none is found. */
    def splitAtAndExclude(cs: Char*): (Layout, Layout) = splitFromLeft(cs) match {
      case Some(i) => (copy(end = i), copy(start = i + 1))
      case None => (this, NoLayout)
    }

    /** Splits right before the found character; `(NoLayout, this)` if none is found. */
    def splitBefore(cs: Char*): (Layout, Layout) = splitFromLeft(cs) match {
      case Some(i) => (copy(end = i), copy(start = i))
      case None => (NoLayout, this)
    }

    private def splitFromLeft(cs: Seq[Char]): Option[Int] =
      split(cs, c => withoutComments.indexOf(c.toInt))

    private def splitFromRight(cs: Seq[Char]): Option[Int] =
      split(cs, c => withoutComments.lastIndexOf(c.toInt))

    /**
     * File offset of the first character of `cs` that `findIndex` locates. Characters are
     * tried in the given order, so earlier ones win regardless of where they occur in the text.
     */
    private def split(cs: Seq[Char], findIndex: Char => Int): Option[Int] =
      cs.iterator.map(findIndex).find(_ >= 0).map(start + _)
  }

  /** A layout for the characters of `source` from `start` until `end`. */
  def apply(source: SourceFile, start: Int, end: Int) = LayoutFromFile(source, start, end)

  /** A layout holding the given text. */
  def apply(s: String) = new Layout {
    val asText = s
  }
}
108 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/sourcegen/Requisite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package sourcegen
7 |
/**
 * A piece of layout (e.g. a separator or whitespace) that is inserted
 * between two layouts only when [[isRequired]] says it is missing.
 */
trait Requisite {
  self =>

  /** Whether this requisite's layout has to be inserted between `l` and `r`. */
  def isRequired(l: Layout, r: Layout): Boolean

  /**
   * Joins `l` and `r`, putting this requisite's layout in between when it
   * is required.
   */
  def apply(l: Layout, r: Layout): Layout =
    if (!isRequired(l, r)) l ++ r
    else insertBetween(l, r)

  /** Places [[getLayout]] between `l` and `r`; subclasses may refine this. */
  protected def insertBetween(l: Layout, r: Layout) = l ++ getLayout ++ r

  /** The layout this requisite contributes when it is required. */
  def getLayout: Layout

  /**
   * Combines two requisites. [[NoRequisite]] is neutral on either side.
   * Otherwise the combination is required if either part is, and when
   * applied it inserts the layout of every part that is required, this
   * requisite's first.
   */
  def ++(other: Requisite): Requisite =
    if (NoRequisite == other) self
    else if (NoRequisite == self) other
    else new Requisite {
      def isRequired(l: Layout, r: Layout) = self.isRequired(l, r) || other.isRequired(l, r)
      def getLayout = self.getLayout ++ other.getLayout
      override def apply(l: Layout, r: Layout) = {
        def layoutIfRequired(req: Requisite) =
          if (req.isRequired(l, r)) req.getLayout else NoLayout
        val first = layoutIfRequired(self)
        val second = layoutIfRequired(other)
        l ++ first ++ second ++ r
      }
    }
}
39 |
/** Factories for the commonly used kinds of [[Requisite]]. */
object Requisite {

  /**
   * A requisite that is satisfied if `req` already ends `l` or starts `r`
   * (surrounding whitespace allowed); otherwise `toPrint` is inserted.
   *
   * Note: `req` is embedded in a regular expression and only the
   * characters `( ) { } [ ]` are escaped; any other regex metacharacter in
   * `req` keeps its regex meaning.
   */
  def allowSurroundingWhitespace(req: String, toPrint: String): Requisite = {

    val escapes = Map(
      '(' -> "\\(", ')' -> "\\)",
      '{' -> "\\{", '}' -> "\\}",
      '[' -> "\\[", ']' -> "\\]")

    val regexSafeString = req.map(c => escapes.getOrElse(c, c.toString)).mkString

    new Requisite {
      def isRequired(l: Layout, r: Layout) = {
        val endsLeft = l.matches("(?ms).*\\s*" + regexSafeString + "\\s*$")
        val startsRight = r.matches("(?ms)^\\s*" + regexSafeString + ".*")
        !(endsLeft || startsRight)
      }
      def getLayout = Layout(toPrint)
    }
  }

  /** [[allowSurroundingWhitespace]] where the required and the printed text are the same. */
  def allowSurroundingWhitespace(str: String): Requisite =
    allowSurroundingWhitespace(str, str)

  /** [[anywhere]] where the required and the printed text are the same. */
  def anywhere(s: String): Requisite = anywhere(s, s)

  /**
   * A requisite that is satisfied if `req` occurs anywhere (outside
   * comments) in `l` or in `r`; otherwise `print` is inserted.
   */
  def anywhere(req: String, print: String): Requisite = new Requisite {
    def isRequired(l: Layout, r: Layout) = !l.contains(req) && !r.contains(req)
    def getLayout = Layout(print)
  }

  /** A single space, required unless `l` ends or `r` starts with whitespace. */
  val Blank = new Requisite {
    def isRequired(l: Layout, r: Layout) = {
      val leftEndsBlank = l.matches("(?s).*\\s+$")
      val rightStartsBlank = r.matches("(?s)^\\s+.*")
      !leftEndsBlank && !rightStartsBlank
    }
    val getLayout = Layout(" ")
  }

  /**
   * A line break (`nl` followed by `indentation`).
   *
   * It is required unless `l` ends with a newline followed only by
   * whitespace, or `r` starts with optional whitespace followed by a
   * newline. When it is inserted and `r` already starts with `indentation`
   * (and `force` is false), only `nl` is inserted so that the indentation
   * is not doubled.
   */
  def newline(indentation: String, nl: String, force: Boolean = false) = new Requisite {
    def isRequired(l: Layout, r: Layout) = {
      val leftEndsWithNewline = l.matches("(?ms).*\r?\n\\s*$")
      val rightStartsWithNewline = r.matches("(?ms)^\\s*\r?\n.*")
      !leftEndsWithNewline && !rightStartsWithNewline
    }
    def getLayout = Layout(nl + indentation)
    override def insertBetween(l: Layout, r: Layout) =
      if (force || !r.asText.startsWith(indentation)) l ++ getLayout ++ r
      else l ++ Layout(nl) ++ r
  }
}
103 |
/** The requisite that is never required and contributes no layout. */
object NoRequisite extends Requisite {
  val getLayout = NoLayout
  def isRequired(l: Layout, r: Layout) = false
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/transformation/TransformableSelections.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.transformation
2 |
3 | import scala.tools.refactoring.common.Selections
4 | import scala.tools.refactoring.common.CompilerAccess
5 |
/** Adds tree transformations to [[Selection]]s that replace the selected trees. */
trait TransformableSelections extends Selections with TreeTransformations {
  self: CompilerAccess =>

  import global._

  implicit class TransformableSelection(selection: Selection) {

    /**
     * Descends top-down (via `topdown` / `matchingChildren`) to the tree(s)
     * with the same position and type as `selection.enclosingTree` and
     * applies `trans` there.
     */
    def descendToEnclosingTreeAndThen(trans: Transformation[Tree, Tree]) =
      topdown {
        matchingChildren {
          predicate { (t: Tree) =>
            t.samePosAndType(selection.enclosingTree)
          } &> trans
        }
      }

    // Replaces the (single) selected top-level tree, identified by position
    // and type, with `replacement`; every other tree is returned unchanged.
    // Only used when exactly one top-level tree is selected.
    private def replaceSingleStatementBy(replacement: Tree) = {
      val original = selection.selectedTopLevelTrees.head
      transform {
        case t if t.samePosAndType(original) =>
          replacement replaces t
        case t =>
          t
      }
    }

    // Replaces the selected statements of a block with `replacement`.
    // If the block has as many statements as there are selected top-level
    // trees and the hierarchy need not be preserved, the block itself is
    // replaced; otherwise a new block is built in which the selected
    // statements are replaced by `replacement`.
    private def replaceSequenceBy(replacement: Tree, preserveHierarchy: Boolean) = {
      transform {
        case block @ Block(stats, expr) =>
          val allStats = (stats :+ expr)
          if (allStats.length == selection.selectedTopLevelTrees.length && !preserveHierarchy) {
            // only replace whole block if allowed to modify tree hierarchy
            replacement replaces block
          } else {
            val newStats = allStats.replaceSequencePreservingPositions(selection.selectedTopLevelTrees, replacement :: Nil)
            mkBlock(newStats) replaces block
          }
      }
    }

    /**
     * Replaces the selection by `replacement`.
     *
     * @param replacement the tree that takes the place of the selected trees
     * @param preserveHierarchy whether the original tree hierarchy must be preserved or
     *                          may be reduced if possible.
     *                          E.g. if a selection contains all trees of the enclosing block:
     *                          - with `preserveHierarchy = false` (the default) the whole block
     *                            is replaced by `replacement`
     *                          - with `preserveHierarchy = true` the block remains, with
     *                            `replacement` as its only child tree
     */
    def replaceBy(replacement: Tree, preserveHierarchy: Boolean = false) = {
      descendToEnclosingTreeAndThen {
        if (selection.selectedTopLevelTrees.length == 1)
          replaceSingleStatementBy(replacement)
        else
          replaceSequenceBy(replacement, preserveHierarchy)
      }
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/util/Memoized.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package util
7 |
object Memoized {

  // Set this to false to temporarily turn off memoization,
  // e.g. for debugging purposes.
  private val MemoizationEnabled = true

  /**
   * Create a function that memoizes its results in a `java.util.WeakHashMap`.
   *
   * Note that memoization is tricky if the objects that
   * are memoized are mutable and the function being memoized
   * returns a value that somehow depends on this mutable property.
   *
   * For example, if we memoize a function that filters trees
   * based on their position, and later modify the position of
   * a tree, the memoized function will return the wrong value.
   *
   * So in order to make it safe, the mkKey function can be used
   * to provide a better key by including the mutable value.
   *
   * The cache holds its keys only weakly: an entry may be dropped as soon
   * as its key is no longer referenced elsewhere, so keys freshly built by
   * `mkKey` on every call may not stay cached for long. Access to the cache
   * is not synchronized.
   *
   * @param mkKey A function that creates the key.
   * @param toMem The function we want to memoize.
   * @return `toMem` itself if memoization is disabled, otherwise a caching wrapper around it.
   */
  def on[X, Y, Z](mkKey: X => Y)(toMem: X => Z): X => Z = {
    if (!MemoizationEnabled) {
      toMem
    } else {
      val cache = new java.util.WeakHashMap[Y, Z]

      (x: X) => {
        val key = mkKey(x)
        // One lookup only: `null` means "nothing usable cached" (never stored,
        // already collected, or `toMem` returned null), so (re)compute and store.
        // A separate containsKey/get pair could also see the key collected in between.
        val cached = cache.get(key)
        if (cached != null) {
          cached
        } else {
          val computed = toMem(x)
          cache.put(key, computed)
          computed
        }
      }
    }
  }

  /**
   * Memoizes `toMem`, using the argument itself as the cache key.
   * Same caching behavior and caveats as [[on]].
   */
  def apply[X, Z](toMem: X => Z): X => Z = on[X, X, Z](x => x)(toMem)
}
79 |
80 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/util/SourceHelpers.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.util
2 |
3 | import scala.annotation.tailrec
4 | import scala.math.min
5 | import scala.reflect.api.Position
6 | import scala.reflect.internal.util.SourceFile
7 | import scala.reflect.internal.util.RangePosition
8 |
object SourceHelpers {

  /**
   * Decides whether `text` occurs in `selection.source` at a position that
   * covers the whole selected range.
   *
   * This is best explained by a few examples (selections are indicated by `[]`):
   *
   *  - `isRangeWithin("a", "[a]") == true`
   *  - `isRangeWithin("ab", "[a]ce") == false`
   *  - `isRangeWithin("ab", "[a]bc") == true`
   *  - `isRangeWithin("ab", "a[b]c") == true`
   *  - `isRangeWithin("ab", "ab[c]") == false`
   *
   */
  def isRangeWithin(text: String, selection: SourceWithSelection): Boolean = {
    val source = selection.source

    if (selection.length > text.length || source.length < text.length) {
      false
    } else {
      // An occurrence of `text` may start at most this many characters
      // before the selection and still reach the selection's end.
      val maxStepsBack = min(text.length - selection.length, selection.start)

      // Whether `text` occurs in `source` starting at offset `start`.
      // Indexes the IndexedSeq directly instead of calling `charAt`, which
      // went through an implicit IndexedSeq[Char] => CharSequence conversion
      // (allocating a wrapper) for every single character.
      @tailrec
      def matchesAt(start: Int, i: Int = 0): Boolean = {
        if (i >= text.length) true
        else if (text.charAt(i) != source(start + i)) false
        else matchesAt(start, i + 1)
      }

      @tailrec
      def tryMatchText(stepsBack: Int = 0): Boolean = {
        if (stepsBack > maxStepsBack) {
          false
        } else {
          val start = selection.start - stepsBack
          if (start + text.length <= source.length && matchesAt(start)) true
          else tryMatchText(stepsBack + 1)
        }
      }

      tryMatchText()
    }
  }

  /** The source text covered by `pos`, or `None` if `pos` is not a range position. */
  def stringCoveredBy(pos: Position): Option[String] = {
    if (pos.isRange) Some(new String(pos.source.content.slice(pos.start, pos.end)))
    else None
  }

  /**
   * Finds all comments in `source`, in order of appearance.
   *
   * @param includeTrailingNewline if true, an optional `\r` and `\n` directly
   *                               following a comment are included in its range
   */
  def findComments(source: SourceFile, includeTrailingNewline: Boolean = true): List[RangePosition] = {
    import scala.tools.refactoring.util.SourceWithMarker
    import scala.tools.refactoring.util.SourceWithMarker.Movements
    import scala.tools.refactoring.util.SourceWithMarker.Movements.charToMovement

    val commentMvnt = {
      if (includeTrailingNewline) Movements.comment ~ '\r'.optional ~ '\n'.optional
      else Movements.comment
    }

    // Walks the source one character at a time; wherever a comment starts,
    // records its range and continues right after it.
    @tailrec
    def doWork(srcWithMarker: SourceWithMarker, acc: List[RangePosition] = Nil): List[RangePosition] = {
      if (srcWithMarker.isDepleted) {
        acc
      } else {
        srcWithMarker.applyMovement(commentMvnt) match {
          case Some(srcAfterComment) =>
            val (commentStart, commentEnd) = (srcWithMarker.marker, srcAfterComment.marker)
            val commentRange = new RangePosition(source, commentStart, commentStart, commentEnd)
            doWork(srcAfterComment, commentRange :: acc)

          case None =>
            doWork(srcWithMarker.stepForward, acc)
        }
      }
    }

    doWork(SourceWithMarker(source.content)).reverse
  }
}
96 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/util/SourceWithSelection.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.util
2 |
/**
 * A source text together with a selected range `[start, end)`.
 * The range must lie within the source and must not be reversed.
 */
case class SourceWithSelection(source: IndexedSeq[Char], start: Int, end: Int) {
  require(start >= 0 && start <= end)
  require(source.length >= end)

  /** Number of selected characters. */
  def length: Int = end - start
}
9 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/util/UnionFind.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.util
2 |
3 | import scala.collection.mutable.Map
4 | import scala.collection.mutable.HashMap
5 |
/**
 * Implements a standard Union-Find (a.k.a. Disjoint Set) data
 * structure with permissive behavior with respect to
 * non-existing elements in the structure (unknown elements are
 * added as new singleton classes when queried for).
 *
 * See Cormen, Thomas H.; Leiserson, Charles E.; Rivest, Ronald
 * L.; Stein, Clifford (2001), "Chapter 21: Data structures for
 * Disjoint Sets", Introduction to Algorithms (Second ed.), MIT
 * Press, pp. 498–524, ISBN 0-262-03293-7
 *
 * Amortized time for a sequence of m {union, find} operations
 * is O(m * InvAckermann(n)) where n is the number of elements
 * and InvAckermann is the inverse of the Ackermann function.
 *
 * Not thread-safe.
 */
class UnionFind[T]() {

  // Parent of each element; a root is its own parent.
  // Reading an unknown element registers it as a new root.
  private val parent: Map[T, T] = new HashMap[T, T] {
    override def default(s: T) = {
      get(s) match {
        case Some(v) => v
        case None => put(s, s); s
      }
    }
  }

  // Rank used by union by rank (only meaningful for roots).
  // Reading an unknown element registers it with rank 1.
  private val rank: Map[T, Int] = new HashMap[T, Int] {
    override def default(s: T) = {
      get(s) match {
        case Some(v) => v
        case None => put(s, 1); 1
      }
    }
  }

  /**
   * Returns the representative (root) of the equivalence class of `s`,
   * adding `s` as a singleton class if it is unknown.
   * Uses path compression.
   */
  def find(s: T): T = {
    val ps = parent(s)
    if (ps == s) s else {
      val cs = find(ps)
      parent(s) = cs
      cs
    }
  }

  /**
   * Unifies the equivalence classes of `x` and `y`.
   * Uses union by rank: the root of lower rank is attached below the other one.
   */
  def union(x: T, y: T): Unit = {
    val cx = find(x)
    val cy = find(y)
    if (cx != cy) {
      // Union by rank compares the ranks of the representatives,
      // not those of the arguments `x` and `y`.
      val rx = rank(cx)
      val ry = rank(cy)
      if (rx > ry) parent(cy) = cx
      else if (rx < ry) parent(cx) = cy
      else {
        rank(cx) += 1
        parent(cy) = cx
      }
    }
  }

  /**
   * Enumerates the equivalence class of element `x` (including `x` itself).
   */
  def equivalenceClass(x: T): List[T] = {
    val root = find(x)
    // Compare roots, not direct parents: an element's parent is not
    // necessarily the root until its path has been compressed. The keys are
    // copied first because `find` updates `parent` while compressing paths.
    parent.keys.toList.filter(find(_) == root)
  }
}
85 |
--------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/util/UniqueNames.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.util
2 |
3 | import java.util.UUID
4 |
/** Generates random, practically unique names, e.g. for test sources. */
object UniqueNames {

  /** A unique file name with the `.scala` extension. */
  def scalaFile(): String = basename() + ".scala"

  /** A unique directory name of the form `src-uid...`. */
  def srcDir(): String = "src-" + uid()

  /** A unique base name without extension. */
  def basename(): String = uid()

  /** A unique name intended as a Scala package name. */
  def scalaPackage(): String = uid()

  // "uid" followed by both halves of a random UUID rendered in base 36;
  // the minus sign of negative halves is replaced by '_'.
  private def uid(): String = {
    val uuid = UUID.randomUUID()
    Seq(uuid.getLeastSignificantBits, uuid.getMostSignificantBits)
      .map(bits => java.lang.Long.toString(bits, Character.MAX_RADIX).replace("-", "_"))
      .mkString("uid", "", "")
  }
}
31 |
--------------------------------------------------------------------------------
/src/test/java/scala/tools/refactoring/common/TracingHelpersTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.common
2 |
3 | import org.junit.Test
4 | import org.junit.Assert._
5 |
6 |
/** Tests for `TracingHelpers.compactify`, which shortens messages. */
class TracingHelpersTest {
  import TracingHelpers._

  /** Short single-line messages come back unchanged. */
  @Test
  def compactifyWithShortMsgs(): Unit = {
    Seq("", "xxx").foreach(msg => assertEquals(msg, compactify(msg)))
  }

  /** Only the first line of a multi-line message is kept, plus a note on the rest. */
  @Test
  def compactifyWithMultilineMsg(): Unit = {
    val multiLine = "1. Do\n2. Make\n3. Say\n4. Think"
    val expected = "1. Do...(3 more lines ommitted)"
    assertEquals(expected, compactify(multiLine))
  }

  /** Very long messages are shortened but keep their beginning. */
  @Test
  def compactifyWithVeryLongMsg(): Unit = {
    val longMsg = "Loooooooonger, then eeeeeeeeeeever before!!!!!!!!!!!!!!!!!"
    val result = compactify(longMsg)
    assertTrue(result.length < longMsg.length)
    assertTrue(result.startsWith("Loooooooonger"))
  }
}
--------------------------------------------------------------------------------
/src/test/java/scala/tools/refactoring/tests/util/ExceptionWrapper.java:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.util;
2 |
3 | import org.junit.rules.TestRule;
4 | import org.junit.runner.Description;
5 | import org.junit.runners.model.Statement;
6 |
7 | import scala.tools.nsc.util.FailedInterrupt;
8 |
9 | /**
10 | * In case an assertion error is caught and wrapped by another error type (as it
11 | * is the case for exceptions that are thrown on the compiler thread), we need
12 | * to manually unwrap them later, in order to let JUnit "see" them.
13 | */
14 | public final class ExceptionWrapper implements TestRule {
15 |
16 | @Override
17 | public Statement apply(final Statement base, final Description description) {
18 | return new Statement() {
19 | @Override
20 | public void evaluate() throws Throwable {
21 | try {
22 | base.evaluate();
23 | } catch (FailedInterrupt e) {
24 | // a FailedInterrupt is thrown when an exception occurs on the compiler thread
25 | throw e.getCause();
26 | }
27 | }
28 | };
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/src/test/java/scala/tools/refactoring/tests/util/ScalaVersion.java:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.util;
2 |
3 | import java.lang.annotation.Retention;
4 | import java.lang.annotation.RetentionPolicy;
5 |
/**
 * Restricts a test method to certain Scala versions.
 *
 * Evaluated by {@link ScalaVersionTestRule} against
 * {@code scala.util.Properties.versionString()}; a value "matches" when the
 * version string contains it as a substring.
 */
@Retention(RetentionPolicy.RUNTIME)
public @interface ScalaVersion {
  /** Run the test only if the version string contains this text ({@code ""} matches every version). */
  String matches() default "";
  /** Skip the test if the version string contains this text (ignored when empty). */
  String doesNotMatch() default "";
}
11 |
--------------------------------------------------------------------------------
/src/test/java/scala/tools/refactoring/tests/util/ScalaVersionTestRule.java:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.util;
2 |
3 | import org.junit.Assume;
4 | import org.junit.rules.MethodRule;
5 | import org.junit.runners.model.FrameworkMethod;
6 | import org.junit.runners.model.Statement;
7 |
8 | import scala.util.Properties;
9 |
10 | public class ScalaVersionTestRule implements MethodRule {
11 |
12 | final class EmptyStatement extends Statement {
13 | @Override
14 | public void evaluate() throws Throwable {
15 | Assume.assumeTrue(false);
16 | }
17 | }
18 |
19 | public Statement apply(Statement stmt, FrameworkMethod meth, Object arg2) {
20 | ScalaVersion onlyOn = meth.getAnnotation(ScalaVersion.class);
21 | String versionString = Properties.versionString();
22 |
23 | if (onlyOn != null) {
24 | if (!onlyOn.doesNotMatch().isEmpty() && versionString.contains(onlyOn.doesNotMatch())) {
25 | return new EmptyStatement();
26 | } else if (versionString.contains(onlyOn.matches())) {
27 | return stmt;
28 | } else {
29 | return new EmptyStatement();
30 | }
31 | } else {
32 | return stmt;
33 | }
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/src/test/java/scala/tools/refactoring/tests/util/TestRules.java:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.util;
2 |
3 | import org.junit.Rule;
4 |
/**
 * Base class for tests that installs the project's JUnit rules:
 * {@link ScalaVersionTestRule} (skips tests whose {@link ScalaVersion}
 * annotation does not fit the running Scala version) and
 * {@link ExceptionWrapper} (unwraps exceptions rethrown as FailedInterrupt).
 */
public abstract class TestRules {

  // all rules need to be public fields (JUnit rejects non-public @Rule fields)
  // NOTE(review): keep the declaration order unless you have confirmed how
  // JUnit orders these two rules relative to each other.

  @Rule
  public final ScalaVersionTestRule rule1 = new ScalaVersionTestRule();

  @Rule
  public final ExceptionWrapper rule2 = new ExceptionWrapper();
}
15 |
--------------------------------------------------------------------------------
/src/test/scala-2.10/README:
--------------------------------------------------------------------------------
1 | Put tests specifically for Scala-2.10.x here
--------------------------------------------------------------------------------
/src/test/scala-2.10/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsScalaSpecificTests.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.implementations.imports
2 |
3 | import scala.tools.refactoring.implementations.OrganizeImports
4 | import scala.tools.refactoring.implementations.OrganizeImports.Dependencies
5 |
/**
 * Organize-imports tests whose outcome depends on the Scala version
 * (Scala 2.10 variant).
 */
class OrganizeImportsScalaSpecificTests extends OrganizeImportsBaseTest {
// Runs organize-imports on `pro` with the ExpandImports strategy, the given
// wildcard packages and import groups, and the given dependency mode.
private def organizeCustomized(
groupPkgs: List[String] = List("java", "scala", "org", "com"),
useWildcards: Set[String] = Set("scalaz", "scalaz.Scalaz"),
dependencies: Dependencies.Value,
organizeLocalImports: Boolean = true)(pro: FileSet) = new OrganizeImportsRefatoring(pro) {
val oiConfig = OrganizeImports.OrganizeImportsConfig(
importsStrategy = Some(OrganizeImports.ImportsStrategy.ExpandImports),
wildcards = useWildcards,
groups = groupPkgs)
val params = {
new refactoring.RefactoringParameters(
deps = dependencies,
organizeLocalImports = organizeLocalImports,
config = Some(oiConfig))
}
}.mkChanges

// `acme.Extended` is only used as the parent of `Tested` (with an inferred
// type argument), so its import must be kept. Ignored on 2.10, see message.
@Ignore("Passes for scala 2.11 only. Implementation bases on 2.11 specific tree structure.")
@Test
def shouldNotRemoveImportWhenExtendedClassHasInferredTypeParam() = new FileSet {
"""
package acme

class Extended[T](val a: T, t: String)
""" isNotModified

"""
/*<-*/
package acme.test

import acme.Extended

class Tested(val id: Int) extends Extended(id, "text")
""" isNotModified
} applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify)
}
43 |
--------------------------------------------------------------------------------
/src/test/scala-2.10/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsWithMacrosTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.implementations.imports
2 |
/**
 * Organize-imports tests involving macros.
 *
 * Refactoring for Scala 2.10 does not support macros, so this 2.10 variant is
 * intentionally left unimplemented: it adds no tests of its own.
 */
class OrganizeImportsWithMacrosTest extends OrganizeImportsBaseTest
5 |
--------------------------------------------------------------------------------
/src/test/scala-2.11/README:
--------------------------------------------------------------------------------
1 | Put tests specifically for Scala-2.11.x here
--------------------------------------------------------------------------------
/src/test/scala-2.11/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsScalaSpecificTests.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.implementations.imports
2 |
3 | import scala.tools.refactoring.implementations.OrganizeImports
4 | import scala.tools.refactoring.implementations.OrganizeImports.Dependencies
5 |
/**
 * Organize-imports tests whose outcome depends on the Scala version
 * (Scala 2.11 variant).
 */
class OrganizeImportsScalaSpecificTests extends OrganizeImportsBaseTest {
// Runs organize-imports on `pro` with the ExpandImports strategy, the given
// wildcard packages and import groups, and the given dependency mode.
private def organizeCustomized(
groupPkgs: List[String] = List("java", "scala", "org", "com"),
useWildcards: Set[String] = Set("scalaz", "scalaz.Scalaz"),
dependencies: Dependencies.Value,
organizeLocalImports: Boolean = true)(pro: FileSet) = new OrganizeImportsRefatoring(pro) {
val oiConfig = OrganizeImports.OrganizeImportsConfig(
importsStrategy = Some(OrganizeImports.ImportsStrategy.ExpandImports),
wildcards = useWildcards,
groups = groupPkgs)
val params = {
new refactoring.RefactoringParameters(
deps = dependencies,
organizeLocalImports = organizeLocalImports,
config = Some(oiConfig))
}
}.mkChanges

// `acme.Extended` is only used as the parent of `Tested` (with an inferred
// type argument), so its import must be kept.
@Test
def shouldNotRemoveImportWhenExtendedClassHasInferredTypeParam() = new FileSet {
"""
package acme

class Extended[T](val a: T, t: String)
""" isNotModified

"""
/*<-*/
package acme.test

import acme.Extended

class Tested(val id: Int) extends Extended(id, "text")
""" isNotModified
} applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify)

// The imported `whitebox` is only used as a prefix (`whitebox.Context`),
// so the import must be kept.
@Test
def shouldNotRemoveImportWhenJustPackage_v3() = new FileSet {
"""
/*<-*/
package tested

import scala.reflect.macros.whitebox

class Tested(val c: whitebox.Context)
""" isNotModified
} applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify)
}
54 |
--------------------------------------------------------------------------------
/src/test/scala-2.12/README:
--------------------------------------------------------------------------------
1 | Put tests specifically for Scala-2.12.x here
--------------------------------------------------------------------------------
/src/test/scala-2.12/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsScalaSpecificTests.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.implementations.imports
2 |
3 | import scala.tools.refactoring.implementations.OrganizeImports
4 | import scala.tools.refactoring.implementations.OrganizeImports.Dependencies
5 |
/**
 * Organize-imports tests whose outcome depends on the Scala version
 * (Scala 2.12 variant).
 */
class OrganizeImportsScalaSpecificTests extends OrganizeImportsBaseTest {
// Runs organize-imports on `pro` with the ExpandImports strategy, the given
// wildcard packages and import groups, and the given dependency mode.
private def organizeCustomized(
groupPkgs: List[String] = List("java", "scala", "org", "com"),
useWildcards: Set[String] = Set("scalaz", "scalaz.Scalaz"),
dependencies: Dependencies.Value,
organizeLocalImports: Boolean = true)(pro: FileSet) = new OrganizeImportsRefatoring(pro) {
val oiConfig = OrganizeImports.OrganizeImportsConfig(
importsStrategy = Some(OrganizeImports.ImportsStrategy.ExpandImports),
wildcards = useWildcards,
groups = groupPkgs)
val params = {
new refactoring.RefactoringParameters(
deps = dependencies,
organizeLocalImports = organizeLocalImports,
config = Some(oiConfig))
}
}.mkChanges

// `acme.Extended` is only used as the parent of `Tested` (with an inferred
// type argument), so its import must be kept.
@Test
def shouldNotRemoveImportWhenExtendedClassHasInferredTypeParam() = new FileSet {
"""
package acme

class Extended[T](val a: T, t: String)
""" isNotModified

"""
/*<-*/
package acme.test

import acme.Extended

class Tested(val id: Int) extends Extended(id, "text")
""" isNotModified
} applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify)

// The imported `whitebox` is only used as a prefix (`whitebox.Context`),
// so the import must be kept.
@Test
def shouldNotRemoveImportWhenJustPackage_v3() = new FileSet {
"""
/*<-*/
package tested

import scala.reflect.macros.whitebox

class Tested(val c: whitebox.Context)
""" isNotModified
} applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify)
}
54 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/implementations/OrganizeImportsAlgosTest.scala:
--------------------------------------------------------------------------------
1 |
2 |
3 | package scala.tools.refactoring.implementations
4 |
5 | import org.junit.Test
6 | import org.junit.Assert._
7 |
/** Tests for `OrganizeImports.Algos.groupImports`. */
class OrganizeImportsAlgosTest {
  import OrganizeImports.Algos

  @Test
  def testGroupImportsWithTrivialExamples(): Unit = {
    testGroupImports(groups = Nil, imports = Nil, expected = Nil)
    testGroupImports(groups = List("org", "com"), imports = Nil, expected = Nil)
    testGroupImports(
      groups = Nil,
      imports = List("com.github.Lausbub"),
      expected = List(List("com.github.Lausbub")))
  }

  @Test
  def testGroupImportsWithSimpleExamples(): Unit = {
    testGroupImports(
      groups = List("org"),
      imports = List("org.junit.Test", "org.junit.Assert._", "language.postfixOps"),
      expected = List(
        List("org.junit.Test", "org.junit.Assert._"),
        List("language.postfixOps")))

    testGroupImports(
      groups = List("java", "scala", "org"),
      imports = List("java.lang.String", "java.lang.Long", "org.junit.Before", "scala.Option.option2Iterable", "language.implicitConversions"),
      expected = List(
        List("java.lang.String", "java.lang.Long"),
        List("scala.Option.option2Iterable"),
        List("org.junit.Before"),
        List("language.implicitConversions")))
  }

  @Test
  def testGroupImportsWithNastyExamples(): Unit = {
    val abAndAc = List("a.b.X", "a.c.Y")

    testGroupImports(
      groups = List("a.c", "a"),
      imports = abAndAc,
      expected = List(List("a.c.Y"), List("a.b.X")))

    testGroupImports(
      groups = List("a", "a.c"),
      imports = abAndAc,
      expected = List(List("a.b.X"), List("a.c.Y")))

    testGroupImports(
      groups = List("a", "a.b", "ab"),
      imports = List("a.A", "a.b.AB", "ab.Ab1", "ab.Ab2", "abc.Abc1", "abc.Abc2"),
      expected = List(
        List("a.A"),
        List("a.b.AB"),
        List("ab.Ab1", "ab.Ab2"),
        List("abc.Abc1", "abc.Abc2")))
  }

  @Test
  def testGroupImportsWithAlternatives(): Unit = {
    val baseImports = List("a.b.X", "a.c.Y", "b.a.Y", "b.c.X")
    val importsWithD = baseImports :+ "d.a.Y"
    val baseExpected = List(List("a.c.Y"), List("a.b.X", "b.c.X"), List("b.a.Y"))

    testGroupImports(groups = List("a.c", "a,b.c", "b"), imports = baseImports, expected = baseExpected)

    testGroupImports(groups = List("a.c", "nonexisting", "a,b.c", "b"), imports = baseImports, expected = baseExpected)

    testGroupImports(
      groups = List("a.c", "*", "a,b.c", "b"),
      imports = importsWithD,
      expected = List(List("a.c.Y"), List("d.a.Y"), List("a.b.X", "b.c.X"), List("b.a.Y")))

    // default group at the end
    testGroupImports(
      groups = List("a.c", "a,b.c", "b"),
      imports = importsWithD,
      expected = baseExpected :+ List("d.a.Y"))

    // disordered alternative
    testGroupImports(groups = List("a.c", "b.c,a", "b"), imports = baseImports, expected = baseExpected)

    // disordered and repeated alternative
    testGroupImports(groups = List("a.c", "b.c,a", "b", "a,b.c"), imports = baseImports, expected = baseExpected)
  }

  // Groups `imports` by `groups`, where the import expression of an import is
  // everything before its last dot, and compares the result with `expected`.
  private def testGroupImports(groups: List[String], imports: List[String], expected: List[List[String]]): Unit = {
    val importExpr: String => String = { imp =>
      val lastDot = imp.lastIndexOf('.')
      assert(lastDot >= 0)
      imp.take(lastDot)
    }

    assertEquals(expected, Algos.groupImports(importExpr)(groups, imports))
  }
}
96 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/RefactoringTestSuite.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package tests
3 |
4 | import org.junit.runner.RunWith
5 | import org.junit.runners.Suite
6 | import analysis._
7 | import common._
8 | import implementations._
9 | import implementations.imports._
10 | import sourcegen._
11 | import transformation._
12 | import util._
13 | import scala.tools.refactoring.implementations.OrganizeImportsAlgosTest
14 | import scala.tools.refactoring.common.TracingHelpersTest
15 |
16 | @RunWith(value = classOf[Suite]) // JUnit's Suite runner: running this class runs every test class listed below.
17 | @Suite.SuiteClasses(value = Array( // Keep in sync with the test classes under src/test: a class missing here is not run as part of the suite.
18 | classOf[AddFieldTest],
19 | classOf[AddImportStatementTest],
20 | classOf[AddMethodTest],
21 | classOf[ChangeParamOrderTest],
22 | classOf[CompilationUnitDependenciesTest],
23 | classOf[CustomFormattingTest],
24 | classOf[DeclarationIndexTest],
25 | classOf[ExpandCaseClassBindingTest],
26 | classOf[ExplicitGettersSettersTest],
27 | classOf[extraction.ExtractCodeTest],
28 | classOf[extraction.ExtractExtractorTest],
29 | classOf[extraction.ExtractionsTest],
30 | classOf[extraction.ExtractMethodTest],
31 | classOf[extraction.ExtractParameterTest],
32 | classOf[extraction.ExtractValueTest],
33 | classOf[ExtractLocalTest],
34 | classOf[ExtractMethodTest],
35 | classOf[ExtractTraitTest],
36 | classOf[FindShadowedTest],
37 | classOf[GenerateHashcodeAndEqualsTest],
38 | classOf[ImportAnalysisTest],
39 | classOf[IndividualSourceGenTest],
40 | classOf[InlineLocalTest],
41 | classOf[InsertionPositionsTest],
42 | classOf[IntroduceProductNTraitTest],
43 | classOf[LayoutTest],
44 | classOf[MarkOccurrencesTest],
45 | classOf[MergeParameterListsTest],
46 | classOf[MoveClassTest],
47 | classOf[MoveConstructorToCompanionObjectTest],
48 | classOf[MultipleFilesIndexTest],
49 | classOf[NameValidationTest],
50 | classOf[OrganizeImportsCollapseSelectorsToWildcardTest],
51 | classOf[OrganizeImportsFullyRecomputeTest],
52 | classOf[OrganizeImportsGroupsTest],
53 | classOf[OrganizeImportsRecomputeAndModifyTest],
54 | classOf[OrganizeImportsTest],
55 | classOf[OrganizeImportsWildcardsTest],
56 | classOf[OrganizeMissingImportsTest],
57 | classOf[EnrichedTreesTest],
58 | classOf[PrependOrDropScalaPackageFromRecomputedTest],
59 | classOf[PrependOrDropScalaPackageKeepTest],
60 | classOf[PrettyPrinterTest],
61 | classOf[RenameTest],
62 | classOf[ReusingPrinterTest],
63 | classOf[ScopeAnalysisTest],
64 | classOf[SelectionDependenciesTest],
65 | classOf[SelectionExpansionsTest],
66 | classOf[SelectionPropertiesTest],
67 | classOf[SelectionsTest],
68 | classOf[SourceGenTest],
69 | classOf[SourceHelperTest],
70 | classOf[SplitParameterListsTest],
71 | classOf[TransformableSelectionTest],
72 | classOf[TreeAnalysisTest],
73 | classOf[TreeChangesDiscovererTest],
74 | classOf[TreeTransformationsTest],
75 | classOf[UnionFindInitTest],
76 | classOf[UnionFindTest],
77 | classOf[UnusedImportsFinderTest],
78 | classOf[SourceWithMarkerTest],
79 | classOf[OrganizeImportsAlgosTest],
80 | classOf[TracingHelpersTest],
81 | classOf[OrganizeImportsWithMacrosTest],
82 | classOf[OrganizeImportsScalaSpecificTests],
83 | classOf[OrganizeImportsEndOfLineTest],
84 | classOf[OccurrencesTest]))
85 | class RefactoringTestSuite {} // Intentionally empty: the annotations above define the suite.
86 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/analysis/FindShadowedTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.analysis
7 |
8 | import tests.util.TestHelper
9 | import org.junit.Assert._
10 |
11 | class FindShadowedTest extends TestHelper { // Counts the vals in a small source that shadow a same-named symbol visible from an enclosing scope or an import.
12 |
13 | import global._
14 |
15 | @Ignore // NOTE(review): currently disabled; the reason is not recorded in this file.
16 | @Test
17 | def findSimpleShadowing(): Unit = { // Expects exactly 3 shadowing vals in the fixture below (see the final assertion).
18 |
19 | val t = treeFrom("""
20 | package shadowing
21 |
22 | object TheShadow {
23 | val i = 1
24 |
25 | def method: Unit = {
26 | val i = ""
27 | ()
28 | }
29 | }
30 |
31 | class Xyz(xyzxyz: Long) {
32 |
33 | def method: Unit = {
34 | val xyzxyz = ""
35 | val i = xyzxyz
36 | ()
37 | }
38 | }
39 |
40 | class Z {
41 | import TheShadow._
42 |
43 | val i = Nil
44 | }""")
45 |
46 | val results = new collection.mutable.ListBuffer[Symbol] // at most one entry per ValDef: the symbol it shadows
47 |
48 | t foreach {
49 | case v @ ValDef(_, name, _, _) =>
50 |
51 | val members = { // symbols visible from the context around v: enclosing scopes plus imports
52 |
53 | def contexts(n: Context): Stream[Context] = n #:: contexts(n.outer) // lazy chain of enclosing contexts; takeWhile below stops it at NoContext
54 |
55 | val context = global.doLocateContext(v.pos).outer
56 |
57 | val fromEnclosingScopes = contexts(context).takeWhile(_ != NoContext) flatMap {
58 | ctx => ctx.scope ++ (if (ctx == ctx.enclClass) ctx.prefix.members else Nil) // class-level contexts also contribute the members of their prefix
59 | }
60 |
61 | val fromImported = (context.imports flatMap (_.allImportedSymbols))
62 |
63 | fromEnclosingScopes ++ fromImported
64 | }
65 |
66 | results ++= members.find(s => s.name == name && s.pos != v.symbol.pos) // same name at a different position, i.e. not v's own symbol
67 |
68 | case _ => Nil // all other trees are ignored
69 | }
70 |
71 | assertEquals(3, results.size)
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/analysis/ImportAnalysisTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.analysis
2 |
3 | import scala.tools.refactoring.tests.util.TestHelper
4 | import scala.tools.refactoring.analysis.ImportAnalysis
5 | import org.junit.Assert._
6 |
7 | class ImportAnalysisTest extends TestHelper with ImportAnalysis { // Tests for buildImportTree and the isImportedAt query on the resulting import tree.
8 | import global._
9 |
10 | @Test
11 | def importTrees() = global.ask { () => // Pins the textual form of the import tree, which also contains the implicit scala.`package`._ and scala.Predef._ imports.
12 | val tree = treeFrom("""
13 | import scala.collection.immutable.{LinearSeq, _}
14 |
15 | object O{
16 | import scala.math._
17 | val a = 1
18 | }
19 |
20 | object P{
21 | import collection.mutable._
22 | import O.a
23 | }
24 | """)
25 |
26 | val it = buildImportTree(tree)
27 | assertEquals("{scala.collection.immutable.List{scala.`package`._{scala.Predef._{scala.collection.immutable.LinearSeq{scala.collection.immutable._{scala.math._{}, scala.collection.mutable._{O.a{}}}}}}}}", it.toString())
28 | }
29 |
30 | @Test
31 | def isImported() = global.ask { () => // The import inside fn is visible only inside fn; the top-level import is also visible at fn's definition.
32 | val tree = treeFrom("""
33 | import scala.math.E
34 |
35 | object O{
36 | def fn = {
37 | import scala.math.Pi
38 | 99 * Pi * E
39 | }
40 | }
41 | """)
42 |
43 | val it = buildImportTree(tree)
44 | val piRef = findSymTree(tree, "value Pi")
45 | val eRef = findSymTree(tree, "value E")
46 | val fnDef = findSymTree(tree, "method fn")
47 |
48 | assertTrue(it.isImportedAt(piRef.symbol, piRef.pos))
49 | assertTrue(it.isImportedAt(eRef.symbol, eRef.pos))
50 |
51 | assertFalse(it.isImportedAt(piRef.symbol, fnDef.pos)) // the local import of Pi is not visible at fn's definition
52 | assertTrue(it.isImportedAt(eRef.symbol, fnDef.pos))
53 | }
54 |
55 | @Test
56 | def isImportedWithWildcard() = global.ask { () => // Like isImported, but Pi and E come from a wildcard import inside fn.
57 | val tree = treeFrom("""
58 | object O{
59 | def fn = {
60 | import scala.math._
61 | 99 * Pi * E
62 | }
63 | }
64 | """)
65 |
66 | val it = buildImportTree(tree)
67 | val piRef = findSymTree(tree, "value Pi")
68 | val eRef = findSymTree(tree, "value E")
69 | val fnDef = findSymTree(tree, "method fn")
70 |
71 | assertTrue(it.isImportedAt(piRef.symbol, piRef.pos))
72 | assertTrue(it.isImportedAt(eRef.symbol, eRef.pos))
73 |
74 | assertFalse(it.isImportedAt(piRef.symbol, fnDef.pos))
75 | }
76 |
77 | @Test
78 | def predefsAreAlwaysImported() = global.ask { () => // println and List are reported as imported although the source has no import.
79 | val tree = treeFrom("""
80 | object O{
81 | println(123)
82 | List(1, 2, 3)
83 | }
84 | """)
85 |
86 | val it = buildImportTree(tree)
87 | val printlnRef = findSymTree(tree, "method println")
88 | val listRef = findSymTree(tree, "object List")
89 |
90 | val oDef = findSymTree(tree, "object O")
91 |
92 | assertTrue(it.isImportedAt(printlnRef.symbol, oDef.pos))
93 | assertTrue(it.isImportedAt(listRef.symbol, oDef.pos))
94 | }
95 |
96 | @Test
97 | @Ignore // NOTE(review): currently disabled; the reason is not recorded in this file.
98 | def importsOfValueMembers() = global.ask { () => // b, imported via `import a._`, should be visible at its use in fn but not at O or at a's definition.
99 | val tree = treeFrom("""
100 | package pkg
101 | object O{
102 | val a = new{ val b = 1 }
103 |
104 | import a._
105 |
106 | def fn = b
107 | }
108 | """)
109 |
110 | val it = buildImportTree(tree)
111 | val bRef = findSymTree(findSymTree(tree, "method fn"), "value b")
112 |
113 | val oDef = findSymTree(tree, "object O")
114 | val aDef = findSymTree(tree, "value a")
115 |
116 | assertFalse(it.isImportedAt(bRef.symbol, oDef.pos))
117 | assertFalse(it.isImportedAt(bRef.symbol, aDef.pos))
118 |
119 | assertTrue(it.isImportedAt(bRef.symbol, bRef.pos))
120 | }
121 |
122 | def findSymTree(t: Tree,s: String) = // First SymTree in t whose symbol prints as s (e.g. "value Pi"); .head throws if there is none.
123 | t.collect{
124 | case t: SymTree if t.symbol.toString == s => t
125 | }.head
126 | }
127 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/analysis/TreeAnalysisTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.analysis
7 |
8 | import tests.util.TestHelper
9 | import org.junit.Assert._
10 | import analysis.GlobalIndexes
11 | import analysis.TreeAnalysis
12 |
13 | class TreeAnalysisTest extends TestHelper with GlobalIndexes with TreeAnalysis { // Tests for inboundLocalDependencies and outboundLocalDependencies of a marked selection.
14 |
15 | import global._
16 |
17 | var index: IndexLookup = null // assigned by withIndex; NOTE(review): presumably the index GlobalIndexes expects to be provided — confirm
18 |
19 | def withIndex(src: String)(body: Tree => Unit ): Unit = { // Parses src, builds a GlobalIndex for it inside global.ask, then runs body on the parsed tree.
20 | val tree = treeFrom(src)
21 | global.ask { () =>
22 | index = GlobalIndex(List(CompilationUnitIndex(tree)))
23 | }
24 | body(tree)
25 | }
26 |
27 | def assertInboundLocalDependencies(expected: String, src: String) = withIndex(src) { tree => // expected: the dependencies joined with ", "
28 |
29 | val selection = findMarkedNodes(src, tree)
30 | val in = global.ask(() => inboundLocalDependencies(selection, selection.selectedSymbols.head.owner)) // second argument: the owner of the first selected symbol
31 | assertEquals(expected, in mkString ", ")
32 | }
33 |
34 | def assertOutboundLocalDependencies(expected: String, src: String) = withIndex(src) { tree => // expected: the dependencies joined with ", "
35 |
36 | val selection = findMarkedNodes(src, tree)
37 | val out = global.ask(() => outboundLocalDependencies(selection))
38 | assertEquals(expected, out mkString ", ")
39 | }
40 |
41 | @Test
42 | def findInboudLocalAndParameter() = { // parameter i and the earlier local a are both read in the selection
43 |
44 | assertInboundLocalDependencies("value i, value a", """
45 | class A9 {
46 | def addThree(i: Int) = {
47 | val a = 1
48 | /*(*/ val b = a + 1 + i /*)*/
49 | val c = b + 1
50 | c
51 | }
52 | }
53 | """)
54 | }
55 |
56 | @Test
57 | def findParameterDependency() = { // only parameter i is read inside the selected for-comprehension
58 |
59 | assertInboundLocalDependencies("value i", """
60 | class A8 {
61 | def addThree(i: Int) = {
62 | val a = 1
63 | /*(*/ val b = for(x <- 0 to i) yield x /*)*/
64 | "done"
65 | }
66 | }
67 | """)
68 | }
69 |
70 | @Test
71 | def findNoDependency() = { // the selection reads no locals or parameters
72 |
73 | assertInboundLocalDependencies("", """
74 | class A7 {
75 | def addThree(i: Int) = {
76 | val a = 1
77 | /*(*/ val b = 2 * 21 /*)*/
78 | b
79 | }
80 | }
81 | """)
82 | }
83 |
84 | @Test
85 | def findDependencyOnMethod() = { // the local method inc is reported alongside parameter i
86 |
87 | assertInboundLocalDependencies("value i, method inc", """
88 | class A6 {
89 | def addThree(i: Int) = {
90 | def inc(j: Int) = j + 1
91 | /*(*/ val b = inc(inc(inc(i))) /*)*/
92 | b
93 | }
94 | }
95 | """)
96 | }
97 |
98 | @Test
99 | def findOutboundDeclarations() = { // b is defined in the selection and used after it
100 |
101 | assertOutboundLocalDependencies("value b", """
102 | class A5 {
103 | def addThree = {
104 | /*(*/ val b = 1 /*)*/
105 | b + b + b
106 | }
107 | }
108 | """)
109 | }
110 |
111 | @Test
112 | def multipleReturnValues() = { // all three vals defined in the selection are used after it
113 |
114 | assertOutboundLocalDependencies("value a, value b, value c", """
115 | class TreeAnalysisTest {
116 | def addThree = {
117 | /*(*/ val a = 'a'
118 | val b = 'b'
119 | val c = 'c'/*)*/
120 | a + b + c
121 | }
122 | }
123 | """)
124 | }
125 |
126 | @Test
127 | def dontReturnArgument() = { // assigning to a var declared before the selection is not an outbound dependency
128 |
129 | assertOutboundLocalDependencies("", """
130 | class TreeAnalysisTest {
131 | def go = {
132 | var a = 1
133 | /*(*/ a = 2 /*)*/
134 | a
135 | }
136 | }
137 | """)
138 | }
139 |
140 | @Test
141 | def findOnClassLevel() = { // a selection at class level reports no local inbound dependencies
142 |
143 | assertInboundLocalDependencies("", """
144 | class Outer {
145 | class B2 {
146 | val a = 1
147 | /*(*/ val b = a + 1 /*)*/
148 |
149 | def addThree(i: Int) = {
150 | val a = 1
151 | val b = 2 * 21
152 | b
153 | }
154 | }
155 | }
156 | """)
157 | }
158 | }
159 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/common/EnrichedTreesTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.common
7 |
8 | import tests.util.TestHelper
9 | import org.junit.Assert._
10 | import common.EnrichedTrees
11 |
12 | class EnrichedTreesTest extends TestHelper with EnrichedTrees { // Tests for originalParentOf / originalLeftSibling / originalRightSibling and namePosition.
13 |
14 | import global._
15 |
16 | def tree = treeFrom("""
17 | package treetest
18 |
19 | class Test {
20 | val test = 42
21 | val test2 = 42
22 | }
23 |
24 | """) // a def: every use of `tree` calls treeFrom again
25 |
26 | @Test
27 | def classHasNoRightSibling() = global.ask { () =>
28 |
29 | val c = tree.find(_.isInstanceOf[ClassDef]).get
30 |
31 | assertFalse(originalRightSibling(c).isDefined)
32 | assertTrue(originalLeftSibling(c).isDefined)
33 | }
34 |
35 | @Test
36 | def templateNoSiblings() = global.ask { () => // NOTE(review): despite the name, the template is expected to have a left sibling (only no right one)
37 |
38 | val c = tree.find(_.isInstanceOf[Template]).get
39 |
40 | assertTrue(originalLeftSibling(c).isDefined)
41 | assertFalse(originalRightSibling(c).isDefined)
42 | }
43 |
44 | @Test
45 | def parentChain() = global.ask { () => // four parent steps up from the literal 42 lead to the PackageDef
46 |
47 | val i = tree.find(_.toString == "42").get
48 |
49 | val root = originalParentOf(i) flatMap (originalParentOf(_) flatMap (originalParentOf(_) flatMap originalParentOf))
50 |
51 | assertTrue(root.get.isInstanceOf[PackageDef])
52 | }
53 |
54 | @Test
55 | def rootHasNoParent() = global.ask { () =>
56 | assertEquals(None, originalParentOf(tree))
57 | }
58 |
59 | @Test
60 | def testSiblings() = global.ask { () => // the first ValDef found (test) has no left sibling and test2 as its right sibling
61 |
62 | val v = tree.find(_.isInstanceOf[ValDef]).get
63 | val actual = originalParentOf(v).get.toString.replaceAll("\r\n", "\n") // normalize Windows line endings before the substring check
64 |
65 | assertTrue(actual.contains("private[this] val test: Int = 42;"))
66 |
67 | assertEquals(None, originalLeftSibling(v))
68 | assertEquals("private[this] val test2: Int = 42", originalRightSibling(v).get.toString)
69 | }
70 |
71 | @Test
72 | def namePositionOfFieldAccessor() = { // NOTE(review): unlike the other tests this one does not run inside global.ask — confirm that is intended
73 | val src = """
74 | object O{
75 | val field = 1
76 | }
77 | """
78 | val root = treeFrom(src)
79 |
80 | val accessor = root.collect{ // first val/def whose decoded name is "field"
81 | case t: ValOrDefDef if t.name.decode == "field" => t
82 | }.head
83 |
84 | assertEquals(src.indexOf("field"), accessor.namePosition().point) // the name position must point at the identifier in the source text
85 | }
86 | }
87 |
88 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/common/OccurrencesTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.common
2 |
3 | import scala.tools.refactoring.tests.util.TestHelper
4 | import scala.tools.refactoring.common.Occurrences
5 | import org.junit.Assert._
6 | import scala.tools.refactoring.analysis.GlobalIndexes
7 |
8 | class OccurrencesTest extends TestHelper with GlobalIndexes with Occurrences { // Tests for termNameOccurrences / defDefParameterOccurrences, which report (offset, length) pairs.
9 | import global._
10 |
11 | var index: IndexLookup = null // assigned by withIndex before each check
12 |
13 | def withIndex(src: String)(body: Tree => Unit): Unit = { // Parses src, builds a GlobalIndex for it inside global.ask, then runs body on the parsed tree.
14 | val tree = treeFrom(src)
15 | global.ask { () =>
16 | index = GlobalIndex(List(CompilationUnitIndex(tree)))
17 | }
18 | body(tree)
19 | }
20 |
21 | @Test
22 | def termOccurrences() = { // 3 occurrences of `a`, the first being the outer `val a = 1`
23 | val src = """
24 | object O{
25 | def fn = {
26 | val a = 1
27 | val b = {
28 | val a = 2
29 | a
30 | }
31 | a * a
32 | }
33 | }
34 | """
35 | withIndex(src) { root =>
36 | val os = global.ask { () => termNameOccurrences(root, "a") }
37 | assertEquals(src.indexOf("a = 1"), os.head._1)
38 | assertTrue(os.forall(_._2 == "a".length))
39 | assertEquals(3, os.length)
40 | }
41 | }
42 |
43 | @Test
44 | def accessorNameOccurrences() = { // the field's definition and its use in fn
45 | val src = """
46 | object O{
47 | val field = 1
48 |
49 | def fn = 2 * field
50 | }
51 | """
52 | withIndex(src) { root =>
53 | val os = global.ask { () => termNameOccurrences(root, "field") }
54 | assertEquals(src.indexOf("field"), os.head._1)
55 | assertTrue(os.forall(_._2 == "field".length))
56 | assertEquals(2, os.length)
57 | }
58 | }
59 |
60 | @Test
61 | def paramsOccurrences() = { // one occurrence list per parameter of fn, in declaration order
62 | val src = """
63 | object O{
64 | def fn(a: Int, b: Int) = {
65 | val c = a * b
66 | a * b * c
67 | }
68 | }
69 | """
70 | withIndex(src) { root =>
71 | val os = global.ask { () => defDefParameterOccurrences(root, "fn") }
72 | val aos = os(0)
73 | val bos = os(1)
74 | assertEquals(src.indexOf("a: Int"), aos.head._1)
75 | assertEquals(src.indexOf("b: Int"), bos.head._1)
76 | assertTrue((aos ::: bos).forall(_._2 == 1))
77 | // two params
78 | assertEquals(2, os.length)
79 | // each name occurs three times
80 | assertEquals(3, aos.length) // JUnit's assertEquals takes (expected, actual)
81 | assertEquals(3, bos.length)
82 | }
83 | }
84 | }
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/common/SelectionPropertiesTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.common
2 |
3 | import scala.tools.refactoring.common.Selections
4 | import scala.tools.refactoring.tests.util.TestHelper
5 |
6 | import org.junit.Assert._
7 | import scala.tools.refactoring.tests.util.TextSelections
8 |
9 | class SelectionPropertiesTest extends TestHelper with Selections { // Tests for boolean properties of a selection (representsValue, representsArgument, mayHaveSideEffects, ...).
10 |
11 | implicit class StringToSel(src: String) { // Enables """...""".selection: parses src and builds a FileSelection over its marked region.
12 | val root = treeFrom(src)
13 | val selection = {
14 | val textSelection = TextSelections.extractOne(src)
15 | FileSelection(root.pos.source.file, root, textSelection.from, textSelection.to)
16 | }
17 | }
18 |
19 | @Test
20 | def representsValue() = global.ask { () => // the selection ends in an expression, so it represents a value
21 | val sel = """
22 | object O{
23 | def fn = {
24 | /*(*/val i = 100
25 | i * 2/*)*/
26 | }
27 | }
28 | """.selection
29 | assertTrue(sel.representsValue)
30 | }
31 |
32 | @Test
33 | def doesNotRepresentValue() = global.ask { () => // the selection ends in a val definition, so it does not represent a value
34 | val sel = """
35 | object O{
36 | def fn = {
37 | /*(*/val i = 100
38 | val b = i * 2/*)*/
39 | }
40 | }
41 | """.selection
42 | assertFalse(sel.representsValue)
43 | }
44 |
45 | @Test
46 | def nonValuePatternsDoNotRepresentValues() = global.ask { () => // wildcard patterns and patterns that bind a variable are not values
47 | val selWildcard = """object O { 1 match { case /*(*/_/*)*/ => () } }""".selection
48 | assertFalse(selWildcard.representsValue)
49 |
50 | val selCtorPattern = """object O { Some(1) match { case /*(*/Some(i)/*)*/ => () } }""".selection
51 | assertFalse(selCtorPattern.representsValue)
52 |
53 | val selBinding = """object O { 1 match { case /*(*/i: Int/*)*/ => i } }""".selection
54 | assertFalse(selBinding.representsValue)
55 |
56 | val selPatAndGuad = """object O { 1 match { case /*(*/i if i > 10/*)*/ => i } }""".selection
57 | assertFalse(selPatAndGuad.representsValue)
58 | }
59 |
60 | @Test
61 | def valuePatternsDoRepresentValues() = global.ask { () => // a constructor pattern with only a literal argument is a value
62 | val selCtorPattern = """object O { Some(1) match { case /*(*/Some(1)/*)*/ => () } }""".selection
63 | assertTrue(selCtorPattern.representsValue)
64 | }
65 |
66 | @Test
67 | def argumentLists() = global.ask { () => // a sub-range of call arguments is an argument, but neither a value nor value definitions
68 | val sel = """
69 | object O{
70 | def fn = {
71 | List(/*(*/1, 2/*)*/, 3)
72 | }
73 | }
74 | """.selection
75 | assertFalse(sel.representsValue)
76 | assertFalse(sel.representsValueDefinitions)
77 | assertTrue(sel.representsArgument)
78 | }
79 |
80 | @Test
81 | def parameter() = global.ask { () => // a selected parameter is a value definition and a parameter, but not a value
82 | val sel = """
83 | object O{
84 | def fn(/*(*/a: Int/*)*/) = {
85 | a
86 | }
87 | }
88 | """.selection
89 | assertFalse(sel.representsValue)
90 | assertTrue(sel.representsValueDefinitions)
91 | assertTrue(sel.representsParameter)
92 | }
93 |
94 | @Test
95 | def multipleParameters() = global.ask { () => // same expectations as parameter(), with two parameters selected
96 | val sel = """
97 | object O{
98 | def fn(/*(*/a: Int, b: Int/*)*/) = {
99 | a * b
100 | }
101 | }
102 | """.selection
103 | assertFalse(sel.representsValue)
104 | assertTrue(sel.representsValueDefinitions)
105 | assertTrue(sel.representsParameter)
106 | }
107 |
108 | @Test
109 | def triggersSideEffects() = global.ask { () => // a selected method that updates a var may have side effects
110 | val sel = """
111 | object O{
112 | var a = 1
113 | /*(*/def fn = {
114 | a += 1
115 | a
116 | }/*)*/
117 | }
118 | """.selection
119 | assertTrue(sel.mayHaveSideEffects)
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/common/SelectionsTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.common
7 |
8 | import tests.util.TestHelper
9 | import org.junit.Assert._
10 | import scala.tools.refactoring.tests.util.TextSelections
11 |
12 | class SelectionsTest extends TestHelper { // Tests which trees and symbols a FileSelection reports for a marked region of source.
13 |
14 | private def getIndexedSelection(src: String) = { // parses src and builds a FileSelection over its marked region
15 | val tree = treeFrom(src)
16 | val textSelection = TextSelections.extractOne(src)
17 | FileSelection(tree.pos.source.file, tree, textSelection.from, textSelection.to)
18 | }
19 |
20 | def selectedLocalVariable(expected: String, src: String) = { // checks the name of selectedSymbolTree's symbol; .get throws if nothing is selected
21 |
22 | val selection = getIndexedSelection(src)
23 |
24 | assertEquals(expected, selection.selectedSymbolTree.get.symbol.name.toString)
25 | }
26 |
27 | def assertSelection(expectedTrees: String, expectedSymbols: String, src: String) = { // compares the selected trees' class names and the selected symbols, each joined with ", "
28 |
29 | val selection = getIndexedSelection(src)
30 |
31 | assertEquals(expectedTrees, selection.allSelectedTrees map (_.getClass.getSimpleName) mkString ", ")
32 | assertEquals(expectedSymbols, selection.selectedSymbols mkString ", ")
33 | }
34 |
35 | @Test
36 | def findValDefInMethod() = {
37 | assertSelection(
38 | "ValDef, Apply, Select, Ident, Ident",
39 | "value b, method +, value a, value i", """
40 | package findValDefInMethod
41 | class A {
42 | def addThree(i: Int) = {
43 | val a = 1
44 | /*(*/ val b = a + i /*)*/
45 | val c = b + 1
46 | c
47 | }
48 | }
49 | """)
50 | }
51 |
52 | @Test
53 | def findIdentInMethod() = {
54 | assertSelection("Ident", "value i", """
55 | package findIdentInMethod
56 | class A {
57 | def addThree(i: Int) = {
58 | val a = 1
59 | val b = a + /*(*/ i /*)*/
60 | val c = b + 1
61 | c
62 | }
63 | }
64 | """)
65 | }
66 |
67 | @Test
68 | def findInMethodArguments() = {
69 | assertSelection("ValDef, TypeTree", "value i", """
70 | package findInMethodArguments
71 | class A {
72 | def addThree(/*(*/ i : Int /*)*/) = {
73 | i
74 | }
75 | }
76 | """)
77 | }
78 |
79 | @Test
80 | def findWholeMethod() = {
81 | assertSelection(
82 | "DefDef, ValDef, TypeTree, Apply, Select, Ident, Literal",
83 | "method addThree, value i, method *, value i", """
84 | package findWholeMethod
85 | class A {
86 | /*(*/
87 | def addThree(i: Int) = {
88 | i * 5
89 | }
90 | /*)*/
91 | }
92 | """)
93 |
94 | }
95 | @Test
96 | def findNothing() = { // an empty marked region selects no trees and no symbols
97 | assertSelection("", "", """
98 | package findNothing
99 | class A {
100 | /*(*/ /*)*/
101 | def addThree(i: Int) = {
102 | i * 5
103 | }
104 | }
105 | """)
106 | }
107 |
108 | @Test
109 | def findSelectedLocal() = { // only the name of the val is inside the marked region
110 | selectedLocalVariable("copy", """
111 | package findSelectedLocal
112 | class A {
113 | def times5(i: Int) = {
114 | val /*(*/copy/*)*/ = i
115 | copy * 5
116 | }
117 | }
118 | """)
119 | }
120 |
121 | @Test
122 | def selectedTheFirstCompleteSymbol() = { // the region starts after `val`, and the expected symbol is i rather than copy
123 | selectedLocalVariable("i", """
124 | package selectedTheFirstCompleteSymbol
125 | class A {
126 | def times5(i: Int) = {
127 | val /*(*/copy = i /*)*/
128 | copy * 5
129 | }
130 | }
131 | """)
132 | }
133 |
134 | @Test
135 | def selectedTheFirstSymbol() = { // the whole val definition is inside the region, so copy is expected
136 | selectedLocalVariable("copy", """
137 | package selectedTheFirstSymbol
138 | class A {
139 | def times5(i: Int) = {
140 | /*(*/ val copy = i /*)*/
141 | copy * 5
142 | }
143 | }
144 | """)
145 | }
146 | }
147 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/implementations/AddFieldTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package tests.implementations
3 |
4 | import common.Change
5 | import org.junit.Assert.assertEquals
6 | import tests.util.TestHelper
7 | import scala.tools.refactoring.implementations._
8 | import language.reflectiveCalls
9 | import scala.tools.refactoring.util.UniqueNames
10 |
11 | class AddFieldTest extends TestHelper { // Tests for the AddField refactoring: each case compares the complete source after adding a val or var.
12 | outer => // self alias used to reach this test's global from inside the anonymous refactoring
13 |
14 | def addField(className: String, valName: String, isVar: Boolean, returnType: Option[String], target: AddMethodTarget, src: String, expected: String) = { // Adds src to the compiler under a unique name, runs AddField on it and compares the applied change with expected.
15 | global.ask { () =>
16 | val refactoring = new AddField {
17 | val global = outer.global
18 | val file = addToCompiler(UniqueNames.basename(), src)
19 | val change = addField(file, className, valName, isVar, returnType, target)
20 | }
21 | assertEquals(expected, Change.applyChanges(refactoring.change, src)) // NOTE(review): reading `change` from the anonymous class is presumably why reflectiveCalls is imported
22 | }
23 | }
24 |
25 | @Test
26 | def addValToObject() = { // a typed val goes into object Main, not class Main
27 | addField("Main", "field", isVar = false, Option("Any"), AddToObject, """
28 | class Main
29 | object Main {
30 |
31 | }""",
32 | """
33 | class Main
34 | object Main {
35 | val field: Any = ???
36 | }""")
37 | }
38 |
39 | @Test
40 | def addValToObject2() = { // without a return type no annotation is emitted; a symbolic name is kept as is
41 | addField("Main", "*", isVar = false, None, AddToObject, """
42 | class Main
43 | object Main {
44 |
45 | }""",
46 | """
47 | class Main
48 | object Main {
49 | val * = ???
50 | }""")
51 | }
52 |
53 | @Test
54 | def addVarToObject() = { // like addValToObject, but with a var
55 | addField("Main", "field", isVar = true, Option("Any"), AddToObject, """
56 | class Main
57 | object Main {
58 |
59 | }""",
60 | """
61 | class Main
62 | object Main {
63 | var field: Any = ???
64 | }""")
65 | }
66 |
67 | @Test
68 | def addVarToObject2() = { // like addValToObject2, but with a var
69 | addField("Main", "*", isVar = true, None, AddToObject, """
70 | class Main
71 | object Main {
72 |
73 | }""",
74 | """
75 | class Main
76 | object Main {
77 | var * = ???
78 | }""")
79 | }
80 |
81 | @Test
82 | def addValToClass() = { // appended after the existing member of class Main, separated by a blank line
83 | addField("Main", "field", isVar = false, Option("Any"), AddToClass, """
84 | object Main
85 | class Main {
86 | def existingMethod = "this is an existing method"
87 | }""",
88 | """
89 | object Main
90 | class Main {
91 | def existingMethod = "this is an existing method"
92 |
93 | val field: Any = ???
94 | }""")
95 | }
96 |
97 | @Test
98 | def addValToInnerClass() = { // a nested class without a body gets one holding the field
99 | addField("Inner", "field", isVar = false, Option("Any"), AddToClass, """
100 | class Main {
101 | class Inner
102 | }""",
103 | """
104 | class Main {
105 | class Inner {
106 | val field: Any = ???
107 | }
108 | }""")
109 | }
110 |
111 | @Test
112 | def addValToCaseClass() = { // a case class without a body gets one holding the field
113 | addField("Main", "field", isVar = false, Option("Any"), AddToClass, """
114 | case class Main""",
115 | """
116 | case class Main {
117 | val field: Any = ???
118 | }""")
119 | }
120 |
121 | @Test
122 | def addValToTrait() = { // AddToClass also applies to a trait
123 | addField("Main", "field", isVar = false, Option("Any"), AddToClass, """
124 | trait Main""",
125 | """
126 | trait Main {
127 | val field: Any = ???
128 | }""")
129 | }
130 |
131 | @Test
132 | def addValByClosestPosition() = { // AddToClosest(30): the expected output puts the field into object Main
133 | addField("Main", "field", isVar = false, Option("Any"), AddToClosest(30), """
134 | class Main {
135 |
136 | }
137 | object Main {
138 |
139 | }""",
140 | """
141 | class Main {
142 |
143 | }
144 | object Main {
145 | val field: Any = ???
146 | }""")
147 | }
148 |
149 | @Test
150 | def addVarByClosestPosition() = { // like addValByClosestPosition, but with a var
151 | addField("Main", "field", isVar = true, Option("Any"), AddToClosest(30), """
152 | class Main {
153 |
154 | }
155 | object Main {
156 |
157 | }""",
158 | """
159 | class Main {
160 |
161 | }
162 | object Main {
163 | var field: Any = ???
164 | }""")
165 | }
166 | }
167 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/implementations/ExplicitGettersSettersTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.implementations
7 |
8 | import implementations.ExplicitGettersSetters
9 | import tests.util.TestHelper
10 | import tests.util.TestRefactoring
11 |
12 |
13 | class ExplicitGettersSettersTest extends TestHelper with TestRefactoring { // Tests for ExplicitGettersSetters on a marked class parameter.
14 | outer => // self alias used to reach this test's global from inside the anonymous refactoring
15 |
16 | def explicitGettersSetters(pro: FileSet) = new TestRefactoringImpl(pro) { // runs the refactoring on pro and returns the resulting changes
17 | val refactoring = new ExplicitGettersSetters {
18 | val global = outer.global
19 | }
20 | val changes = performRefactoring()
21 | }.changes
22 |
23 | @Test
24 | @ScalaVersion(doesNotMatch = "2.12") // NOTE(review): presumably excludes this test when running on Scala 2.12 — confirm
25 | def oneVarFromMany() = new FileSet { // a var parameter becomes private var _i with getter i and setter i_=
26 | """
27 | package oneFromMany
28 | class Demo(val a: String, /*(*/var i: Int/*)*/ ) {
29 | def doNothing = ()
30 | }
31 | """ becomes
32 | """
33 | package oneFromMany
34 | class Demo(val a: String, /*(*/private var _i: Int/*)*/ ) {
35 | def i = {
36 | _i
37 | }
38 |
39 | def i_=(i: Int) = {
40 | _i = i
41 | }
42 | def doNothing = ()
43 | }
44 | """
45 | } applyRefactoring(explicitGettersSetters)
46 |
47 | @Test
48 | @ScalaVersion(doesNotMatch = "2.12")
49 | def oneValFromMany() = new FileSet { // a val parameter becomes a plain parameter _i with only a getter i
50 | """
51 | package oneFromMany
52 | class Demo(val a: String, /*(*/val i: Int/*)*/ ) {
53 | def doNothing = ()
54 | }
55 | """ becomes
56 | """
57 | package oneFromMany
58 | class Demo(val a: String, /*(*/_i: Int/*)*/ ) {
59 | def i = {
60 | _i
61 | }
62 |
63 | def doNothing = ()
64 | }
65 | """
66 | } applyRefactoring(explicitGettersSetters)
67 |
68 | @Test
69 | @ScalaVersion(doesNotMatch = "2.12")
70 | def singleVal() = new FileSet { // a class without a body gets one holding the getter
71 | """
72 | package oneFromMany
73 | class Demo( /*(*/val i: Int/*)*/ )
74 | """ becomes
75 | """
76 | package oneFromMany
77 | class Demo( /*(*/_i: Int/*)*/ ) {
78 | def i = {
79 | _i
80 | }
81 | }
82 | """
83 | } applyRefactoring(explicitGettersSetters)
84 |
85 | @Test
86 | @ScalaVersion(doesNotMatch = "2.12")
87 | def singleValWithEmptyBody() = new FileSet { // the getter is inserted into the existing empty body
88 | """
89 | package oneFromMany
90 | class Demo( /*(*/val i: Int/*)*/ ) {
91 |
92 | }
93 | """ becomes
94 | """
95 | package oneFromMany
96 | class Demo( /*(*/_i: Int/*)*/ ) {
97 | def i = {
98 | _i
99 | }
100 | }
101 | """
102 | } applyRefactoring(explicitGettersSetters)
103 | }
104 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/implementations/extraction/ExtractionsTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.implementations.extraction
2 |
3 | import scala.tools.refactoring.tests.util.TestHelper
4 | import scala.tools.refactoring.implementations.extraction.Extractions
5 | import org.junit.Assert._
6 |
7 | class ExtractionsTest extends TestHelper with Extractions { // Tests how many extraction targets an ExtractionCollector is offered for a selection.
8 | @Test
9 | def findExtractionTargets() = {
10 | val s = toSelection("""
11 | object O{
12 | def fn = {
13 | val a = 1
14 | /*(*/2 * a/*)*/
15 | }
16 | }
17 | """)
18 |
19 | assertEquals(2, TestCollector.targetsFor(s).length)
20 | }
21 |
22 | @Test
23 | def noExtractionTargetsForSyntheticScopes() = {
24 | val s = toSelection("""
25 | object O{
26 | def fn =
27 | 1 :: 2 :: /*(*/Nil/*)*/
28 | }
29 | """)
30 |
31 | assertEquals(2, TestCollector.targetsFor(s).length)
32 | }
33 |
34 | @Test
35 | def noExtractionTargetsForCasesWithSelectedPattern() = {
36 | val s = toSelection("""
37 | object O{
38 | 1 match {
39 | case /*(*/i/*)*/ => i
40 | }
41 | }
42 | """)
43 |
44 | assertEquals(1, TestCollector.targetsFor(s).length)
45 | }
46 |
47 | /** Collector that records the extraction targets it is offered instead of building extractions. */
48 | object TestCollector extends ExtractionCollector[Extraction] {
49 | var extractionTargets: List[ExtractionTarget] = Nil
50 |
51 | /**
52 | * Runs `collect` on `s` and returns the targets passed to `createExtractions` during that run.
53 | * The record is cleared first: this object is shared by all tests, so a run that never reaches
54 | * `createExtractions` must yield Nil rather than the targets left over from another test.
55 | */
56 | def targetsFor(s: Selection): List[ExtractionTarget] = {
57 | extractionTargets = Nil
58 | this.collect(s)
59 | extractionTargets
60 | }
61 |
62 | def isValidExtractionSource(s: Selection) = true
63 |
64 | def createExtractions(source: Selection, targets: List[ExtractionTarget], name: String) = {
65 | extractionTargets = targets
66 | Nil
67 | }
68 | }
69 | }
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsBaseTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.implementations.imports
7 |
8 | import implementations.OrganizeImports
9 | import sourcegen.Formatting
10 | import tests.util.TestHelper
11 | import tests.util.TestRefactoring
12 |
13 | abstract class OrganizeImportsBaseTest extends TestHelper with TestRefactoring { // Shared setup for the organize-imports test classes.
14 |
15 | abstract class OrganizeImportsRefatoring(pro: FileSet, formatting: Formatting = new Formatting{}) extends TestRefactoringImpl(pro) { // NOTE(review): "Refatoring" is misspelled, but subclasses refer to this name. Subclasses define params; mkChanges runs the refactoring with them.
16 | val refactoring = new OrganizeImports {
17 | val global = OrganizeImportsBaseTest.this.global
18 | override val dropScalaPackage = formatting.dropScalaPackage // formatting options are taken from the constructor argument
19 | override val lineDelimiter = formatting.lineDelimiter
20 | }
21 | type RefactoringParameters = refactoring.RefactoringParameters
22 | val params: RefactoringParameters
23 | def mkChanges = performRefactoring(params)
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsCollapseSelectorsToWildcardTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.implementations.imports
2 |
3 | import scala.tools.refactoring.implementations.OrganizeImports
4 |
class OrganizeImportsCollapseSelectorsToWildcardTest extends OrganizeImportsBaseTest {

  /**
   * Organize-imports configuration using the collapse strategy, with a limit of
   * two individual imports per selector list (the tests below expect
   * three-selector imports to become wildcards). Packages named in `exclude`
   * are exempt from collapsing.
   */
  private def wildcardCollapseConfig(exclude: Set[String]) = {
    val maxIndividualImports = 2
    val toWildcard = OrganizeImports.CollapseToWildcardConfig(maxIndividualImports, exclude)
    OrganizeImports.OrganizeImportsConfig(
      importsStrategy = Some(OrganizeImports.ImportsStrategy.CollapseImports),
      collapseToWildcardConfig = Some(toWildcard))
  }

  /** Runs organize imports (fully recomputing dependencies) with the collapsing configuration. */
  def organize(exclude: Set[String] = Set.empty)(pro: FileSet) = {
    val collapsingOiConfig = wildcardCollapseConfig(exclude)
    new OrganizeImportsRefatoring(pro) {
      import refactoring._
      val params = new RefactoringParameters(deps = Dependencies.FullyRecompute, config = Some(collapsingOiConfig))
    }.mkChanges
  }

  @Test
  def collapseImportSelectorsToWildcard() = new FileSet {
    """
import scala.math.{BigDecimal, BigInt, Numeric}

object A {
(BigDecimal, BigInt, Numeric)
}""" becomes
    """
import scala.math._

object A {
(BigDecimal, BigInt, Numeric)
}"""
  } applyRefactoring organize()

  @Test
  def dontCollapseImportsWhenRename() = new FileSet {
    """
package acme
import scala.math.{BigDecimal, BigInt, Numeric => N}

object A {
(BigDecimal, BigInt, N)
}""" isNotModified
  } applyRefactoring organize()

  @Test
  def dontCollapseWhenCollidingWithExplicitImport() = new FileSet {
    """
import scala.collection.immutable.{HashSet, BitSet, HashMap}
import scala.collection.mutable.{ArrayStack, ArrayBuilder, ArrayBuffer}

object MyObject {
(BitSet, HashMap, HashSet)
(ArrayBuffer, ArrayBuilder, ArrayStack)
}""" becomes
    """
import scala.collection.immutable._
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, ArrayStack}

object MyObject {
(BitSet, HashMap, HashSet)
(ArrayBuffer, ArrayBuilder, ArrayStack)
}"""
  } applyRefactoring organize()

  @Test
  def dontCollapseWhenPackageInExcludes() = new FileSet {
    val before = """
import scala.collection.immutable.{BitSet, HashMap, HashSet}

object MyObject {
(BitSet, HashMap, HashSet)
}"""

    before becomes before
  } applyRefactoring organize(Set("scala.collection.immutable"))

}
76 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsEndOfLineTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package tests.implementations.imports
3 |
4 | import scala.tools.refactoring.implementations.OrganizeImports
5 | import scala.tools.refactoring.implementations.OrganizeImports.Dependencies
6 | import scala.tools.refactoring.sourcegen.Formatting
7 |
class OrganizeImportsEndOfLineTest extends OrganizeImportsBaseTest {

  /**
   * Runs organize imports with the expand strategy under `formatting`.
   *
   * @param groupPkgs            package prefixes used for grouping
   * @param useWildcards         packages whose imports are kept as wildcards
   * @param dependencies         dependency computation mode
   * @param organizeLocalImports whether imports inside definitions are organized too
   */
  private def organizeCustomized(
      formatting: Formatting,
      groupPkgs: List[String] = List("java", "scala", "org", "com"),
      useWildcards: Set[String] = Set("scalaz", "scalaz.Scalaz"),
      dependencies: Dependencies.Value = Dependencies.FullyRecompute,
      organizeLocalImports: Boolean = true)(pro: FileSet) = {
    val expandConfig = OrganizeImports.OrganizeImportsConfig(
      importsStrategy = Some(OrganizeImports.ImportsStrategy.ExpandImports),
      wildcards = useWildcards,
      groups = groupPkgs)
    new OrganizeImportsRefatoring(pro, formatting) {
      val params = new refactoring.RefactoringParameters(
        deps = dependencies,
        organizeLocalImports = organizeLocalImports,
        config = Some(expandConfig))
    }.mkChanges
  }

  /** Default formatting except for the line delimiter, which is `eol`. */
  private def formattingWith(eol: String): Formatting =
    new Formatting { override def lineDelimiter = eol }

  private def organizeWithUnixEOL(pro: FileSet) = organizeCustomized(formatting = formattingWith("\n"))(pro)
  private def organizeWithWindowsEOL(pro: FileSet) = organizeCustomized(formatting = formattingWith("\r\n"))(pro)

  @Test
  def shouldPreserveUnixLineSeparator_v1() = new FileSet {
    "package testunix\n\nimport scala.util.Try\nimport java.util.List" becomes
      "package testunix\n\n"
  } applyRefactoring organizeWithUnixEOL

  @Test
  def shouldPreserveUnixLineSeparator_v2() = new FileSet {
    "package testunix\n\nimport scala.util.Try\nimport java.util.List\n" becomes
      "package testunix\n\n"
  } applyRefactoring organizeWithUnixEOL

  @Test
  def shouldPreserveUnixLineSeparator_v3() = new FileSet {
    "package testunix\n\nimport scala.util.Try\nimport java.util.List\n\nclass A(val t: Try[Int], val l: List[Int])" becomes
      "package testunix\n\nimport java.util.List\n\nimport scala.util.Try\n\nclass A(val t: Try[Int], val l: List[Int])"
  } applyRefactoring organizeWithUnixEOL

  @Test
  def shouldPreserveUnixLineSeparator_v4() = new FileSet {
    "package testunix\n\nimport scala.util.Try\nimport scala.util.Either\n\nclass A(val t: Try[Int], val l: Either[Int, Int])" becomes
      "package testunix\n\nimport scala.util.Either\nimport scala.util.Try\n\nclass A(val t: Try[Int], val l: Either[Int, Int])"
  } applyRefactoring organizeWithUnixEOL

  @Test
  def shouldPreserveWindowsLineSeparator_v1() = new FileSet {
    "package testwin\r\n\r\nimport scala.util.Try\r\nimport java.util.List" becomes
      "package testwin\r\n\r\n"
  } applyRefactoring organizeWithWindowsEOL

  @Test
  def shouldPreserveWindowsLineSeparator_v2() = new FileSet {
    "package testwin\r\n\r\nimport scala.util.Try\r\nimport java.util.List\r\n" becomes
      "package testwin\r\n\r\n"
  } applyRefactoring organizeWithWindowsEOL

  @Test
  def shouldPreserveWindowsLineSeparator_v3() = new FileSet {
    "package testwin\r\n\r\nimport scala.util.Try\r\nimport java.util.List\r\n\r\nclass A(val t: Try[Int], val l: List[Int])" becomes
      "package testwin\r\n\r\nimport java.util.List\r\n\r\nimport scala.util.Try\r\n\r\nclass A(val t: Try[Int], val l: List[Int])"
  } applyRefactoring organizeWithWindowsEOL

  @Test
  def shouldPreserveWindowsLineSeparator_v4() = new FileSet {
    "package testwin\r\n\r\nimport scala.util.Try\r\nimport scala.util.Either\r\n\r\nclass A(val t: Try[Int], val l: Either[Int, Int])" becomes
      "package testwin\r\n\r\nimport scala.util.Either\r\nimport scala.util.Try\r\n\r\nclass A(val t: Try[Int], val l: Either[Int, Int])"
  } applyRefactoring organizeWithWindowsEOL
}
78 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsWildcardsTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.implementations.imports
7 |
8 | import scala.tools.refactoring.implementations.OrganizeImports
9 |
class OrganizeImportsWildcardsTest extends OrganizeImportsBaseTest {

  /**
   * Runs organize imports with the preserve-wildcards strategy.
   *
   * @param groups names whose imports should be kept or turned into wildcards
   *               (passed as the configuration's `wildcards`)
   */
  def organize(groups: Set[String])(pro: FileSet) = {
    val preserveWildcardsConfig = OrganizeImports.OrganizeImportsConfig(
      importsStrategy = Some(OrganizeImports.ImportsStrategy.PreserveWildcards),
      wildcards = groups)
    new OrganizeImportsRefatoring(pro) {
      import refactoring._
      val params = new RefactoringParameters(deps = Dependencies.FullyRecompute, config = Some(preserveWildcardsConfig))
    }.mkChanges
  }

  /** Input shared by the first two tests; it imports from `Set` via a wildcard. */
  val source = """
import scala.collection.mutable.Set
import org.xml.sax.Attributes
import Set._

trait Temp {
// we need some code that use the imports
val z: (Attributes, Set[String])
println(apply(""))
}
"""

  @Test
  def noGrouping() = new FileSet {
    source becomes
      """
import org.xml.sax.Attributes
import scala.collection.mutable.Set
import scala.collection.mutable.Set.apply

trait Temp {
// we need some code that use the imports
val z: (Attributes, Set[String])
println(apply(""))
}
"""
  } applyRefactoring organize(Set())

  @Test
  def simpleWildcard() = new FileSet {
    source becomes
      """
import org.xml.sax.Attributes
import scala.collection.mutable.Set
import scala.collection.mutable.Set._

trait Temp {
// we need some code that use the imports
val z: (Attributes, Set[String])
println(apply(""))
}
"""
  } applyRefactoring organize(Set("scala.collection.mutable.Set"))

  @Test
  def renamedImport() = new FileSet {
    """
import java.lang.Integer.{valueOf => vo}
import java.lang.Integer.toBinaryString
import java.lang.String.valueOf

trait Temp {
valueOf(5)
vo("5")
toBinaryString(27)
}
""" becomes
    """
import java.lang.Integer._
import java.lang.Integer.{valueOf => vo}
import java.lang.String.valueOf

trait Temp {
valueOf(5)
vo("5")
toBinaryString(27)
}
"""
  } applyRefactoring organize(Set("java.lang.Integer"))

  @Test
  def multipleImportsOneWildcard() = new FileSet {
    """
import java.lang.Integer.valueOf
import java.lang.Integer.toBinaryString
import java.lang.Double.toHexString

trait Temp {
valueOf("5")
toBinaryString(27)
toHexString(5)
}
""" becomes
    """
import java.lang.Double.toHexString
import java.lang.Integer._

trait Temp {
valueOf("5")
toBinaryString(27)
toHexString(5)
}
"""
  } applyRefactoring organize(Set("java.lang.Integer"))

}
116 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/implementations/imports/UnusedImportsFinderTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.implementations.imports
7 |
8 | import implementations.UnusedImportsFinder
9 | import tests.util.TestHelper
10 | import scala.tools.refactoring.util.UniqueNames
11 |
class UnusedImportsFinderTest extends TestHelper {
  outer =>

  /**
   * Adds `src` to the compiler under a unique name, runs the unused-imports
   * finder on its compilation unit (on the compiler thread) and asserts that
   * the findings, joined with ", ", equal `expected`.
   */
  def findUnusedImports(expected: String, src: String): Unit = {
    val found = global.ask { () =>
      new UnusedImportsFinder {

        val global = outer.global

        val unit = global.unitOfFile(addToCompiler(UniqueNames.basename(), src))

        def compilationUnitOfFile(f: AbstractFile) = Some(unit)

        // `this.` selects the finder's own `findUnusedImports(unit)` rather than
        // the enclosing test helper that shares the name.
        val unuseds = this.findUnusedImports(unit)
      }.unuseds
    }

    val actual = found.mkString(", ")
    org.junit.Assert.assertEquals(expected, actual)
  }

  @Test
  def simpleUnusedType(): Unit = findUnusedImports(
    "(ListBuffer,2)",
    """
import scala.collection.mutable.ListBuffer

object Main {val s: String = "" }
"""
  )

  @Test
  def typeIsUsedAsVal(): Unit = findUnusedImports(
    "",
    """
import scala.collection.mutable.ListBuffer

object Main {val s = new ListBuffer[Int] }
"""
  )

  @Test
  def typeIsImportedFrom(): Unit = findUnusedImports(
    "",
    """
class Forest {
class Tree
}

class UsesTrees {
val forest = new Forest
import forest._
val x = new Tree
}
"""
  )

  @Test
  def wildcardImports(): Unit = findUnusedImports(
    "",
    """
import scala.util.control.Exception._

class UsesTrees {
val plugin = ScalaPlugin.plugin
import plugin._
()
}
"""
  )

  @Test
  def wildcardImportsFromValsAreIgnored(): Unit = findUnusedImports(
    "",
    """
object ScalaPlugin {
var plugin: String = _
}

class UsesTrees {
val plugin = ScalaPlugin.plugin
import plugin._
()
}
"""
  )

  @Test
  def importFromJavaClass(): Unit = findUnusedImports(
    "",
    """
import java.util.Date

object ScalaPlugin {
import Date._
val max = parse(null)
}
"""
  )

  @Test
  def ignoreImportIsNeverunused(): Unit = findUnusedImports(
    "",
    """
import java.util.{Date => _, _}

object NoDate {
var x: Stack[Int] = null
}
"""
  )

  // more tests are in organize imports
}
126 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/sourcegen/CustomFormattingTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.sourcegen
7 |
8 | import tests.util.TestHelper
9 | import org.junit.Assert._
10 | import sourcegen.SourceGenerator
11 | import scala.tools.refactoring.implementations.OrganizeImports
12 | import scala.tools.refactoring.tests.util.TestRefactoring
13 |
14 |
class CustomFormattingTest extends TestHelper with TestRefactoring with SourceGenerator {

  /**
   * Spacing put inside the braces of multi-selector imports. Tests assign it
   * before generating code; it is read both by this class's source generator
   * and by the organize-imports refactoring built below.
   */
  @volatile
  private var surroundingImport = ""

  override def spacingAroundMultipleImports = surroundingImport

  /** Organize-imports test driver whose import spacing follows `surroundingImport`. */
  abstract class OrganizeImportsRefatoring(pro: FileSet) extends TestRefactoringImpl(pro) {
    val refactoring = new OrganizeImports {
      val global = CustomFormattingTest.this.global
      override def spacingAroundMultipleImports = surroundingImport
    }
    type RefactoringParameters = refactoring.RefactoringParameters
    val params: RefactoringParameters
    def mkChanges = performRefactoring(params)
  }

  /** Runs organize imports with the collapse strategy over `pro`. */
  def organize(pro: FileSet) = {
    val collapseImportsConfig = OrganizeImports.OrganizeImportsConfig(
      importsStrategy = Some(OrganizeImports.ImportsStrategy.CollapseImports)
    )
    new OrganizeImportsRefatoring(pro) {
      val params = new RefactoringParameters(config = Some(collapseImportsConfig))
    }.mkChanges
  }


  @Test
  def testSingleSpace(): Unit = {

    val tree = treeFrom("""
package test
import scala.collection.{MapLike, MapProxy}
""")

    surroundingImport = " "

    val expected = """
package test
import scala.collection.{ MapLike, MapProxy }
"""
    assertEquals(expected, createText(tree, Some(tree.pos.source)))
  }

  @Test
  def collapse() = {
    surroundingImport = " "

    new FileSet {
      """
import java.lang.String
import java.lang.Object

object Main {val s: String = ""; var o: Object = null}
""" becomes
      """
import java.lang.{ Object, String }

object Main {val s: String = ""; var o: Object = null}
"""
    } applyRefactoring organize
  }
}
75 |
76 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/sourcegen/LayoutTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.sourcegen
7 |
8 | import org.junit.Test
9 | import org.junit.Assert._
10 | import sourcegen._
11 |
12 | import language.postfixOps
13 |
/**
 * Tests for concatenating `Fragment`s, `Layout`s and `Requisite`s and
 * rendering the result with `asText`.
 *
 * Every rendered expression is parenthesized and rendered with `.asText`
 * instead of postfix syntax (`x asText`). Postfix syntax needs the
 * `postfixOps` language feature and is discouraged; the parentheses spell out
 * the precedence that postfix application implied.
 */
class LayoutTest {

  @Test
  def simpleConcatenation(): Unit = {
    assertEquals("ab", (Fragment("a") ++ Fragment("b")).asText)
    assertEquals("abc", (Fragment("a") ++ Fragment("b") ++ Fragment("c")).asText)
    assertEquals("ab", (Layout("a") ++ Layout("b")).asText)
    assertEquals("abc", (Layout("a") ++ Fragment("b") ++ Layout("c")).asText)
    assertEquals("abc", (Fragment("a") ++ Layout("b") ++ Fragment("c")).asText)
  }

  @Test
  def concatenationsWithEmpty(): Unit = {
    val N = NoLayout
    val F = EmptyFragment

    assertEquals("", N.asText)
    assertEquals("", F.asText)
    assertEquals("", (N ++ F).asText)
    assertEquals("", (N ++ N).asText)
    assertEquals("", (F ++ N).asText)
    assertEquals("", (F ++ F).asText)

    assertEquals("a", (Fragment("a") ++ F).asText)
    assertEquals("a", (Fragment("a") ++ N).asText)

    assertEquals("a", (F ++ Fragment("a")).asText)
    assertEquals("a", (N ++ Fragment("a")).asText)

    assertEquals("ab", (Fragment("a") ++ F ++ Layout("b")).asText)
    assertEquals("ab", (Layout("a") ++ N ++ Fragment("b")).asText)
  }

  @Test
  def complexConcatenations(): Unit = {
    val a = Layout("a")
    val b = Layout("b")
    val c = Layout("c")

    // Pattern variables are named differently from a/b/c to avoid shadowing them.
    (Fragment(a, b, c) ++ Fragment(a, b, c)) match {
      case Fragment(leading, center, trailing) =>
        assertEquals("a", leading.asText)
        assertEquals("bcab", center.asText)
        assertEquals("c", trailing.asText)
    }

    (Fragment(a, b, c) ++ a) match {
      case Fragment(leading, center, trailing) =>
        assertEquals("a", leading.asText)
        assertEquals("b", center.asText)
        assertEquals("ca", trailing.asText)
    }

    (b ++ Fragment(a, b, c)) match {
      case Fragment(leading, center, trailing) =>
        assertEquals("ba", leading.asText)
        assertEquals("b", center.asText)
        assertEquals("c", trailing.asText)
    }

    (Fragment("a") ++ Fragment("b")) match {
      case Fragment(leading, center, trailing) =>
        assertEquals("", leading.asText)
        assertEquals("ab", center.asText)
        assertEquals("", trailing.asText)
    }
  }

  @Test
  def preserveRequisites(): Unit = {
    val r = Requisite.allowSurroundingWhitespace(",")
    val a = Fragment("a")
    val b = Fragment("b")
    val x = Layout("x")

    assertEquals("a,x", (a ++ r ++ x).asText)
    assertEquals("a,b", (a ++ r ++ b).asText)
  }

  @Test
  def requisitesAreBetween(): Unit = {
    val r = Requisite.allowSurroundingWhitespace(",")
    val a = Fragment(Layout("a"), Layout("b"), Layout("c"))
    val b = Fragment(Layout("x"), Layout("y"), Layout("z"))

    assertEquals("abc,xyz", (a ++ r ++ b).asText)
  }

  @Test
  def requisitesAreOnlyUsesWhenNeeded1(): Unit = {
    val r = Requisite.allowSurroundingWhitespace(",")
    val a = Fragment(Layout("a"), Layout("b"), Layout(","))
    val b = Fragment(Layout("x"), Layout("y"), Layout("z"))

    assertEquals("ab,xyz", (a ++ r ++ b).asText)
  }

  @Test
  def requisitesAreOnlyUsesWhenNeeded2(): Unit = {
    val r = Requisite.allowSurroundingWhitespace(",")
    val a = Fragment(Layout("a"), Layout("b"), Layout("c"))
    val b = Fragment(Layout(","), Layout("y"), Layout("z"))

    assertEquals("abc,yzabc", (a ++ r ++ b ++ a).asText)
  }
}
120 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/sourcegen/SourceHelperTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.sourcegen
7 |
8 | import tests.util.TestHelper
9 | import org.junit.Assert._
10 | import sourcegen.SourceUtils
11 |
class SourceHelperTest extends TestHelper {

  import SourceUtils._

  /** Checks `splitComment` on a table of inputs, in order. */
  @Test
  def liftSingleLineComment(): Unit = {
    // (input, expected result of splitComment)
    val cases = List(
      ("abc//x", ("abc ", " //x")),
      ("x/**/x", ("x x", " /**/ ")),
      ("5/**/*5", ("5 *5", " /**/ ")),
      ("5/*/**/*/*5", ("5 *5", " /*/**/*/ ")),
      ("4/*/**/*//2", ("4 /2", " /*/**/*/ ")))

    cases.foreach { case (input, expected) =>
      assertEquals(expected, splitComment(input))
    }
  }

  /** An empty comment between `r` and `*` must not leave a stray operator. */
  @Test
  def multiplication(): Unit = {
    val expected = """
object A {
val r = 3
val p = r * r
}"""
    val commented = """
object A {
val r = 3
val p = r/**/* r
}"""
    assertEquals(expected, stripComment(commented))
  }

  @Test
  def stripCommentInClass(): Unit = {
    val expected = stripWhitespacePreservers("""
class A {
def extractFrom(): Int = {
val a = 1
a + 1 ▒
}
}""")
    val actual = stripComment("""
class A {
def extractFrom(): Int = {
val a = 1
/*(*/ a + 1/*)*/
}
}""")
    assertEquals(expected, actual)
  }
}
60 |
61 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/transformation/TransformableSelectionTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.transformation
2 |
3 | import scala.tools.refactoring.tests.util.TestHelper
4 | import org.junit.Assert._
5 | import scala.tools.refactoring.transformation.TransformableSelections
6 | import scala.tools.refactoring.tests.util.TextSelections
7 |
class TransformableSelectionTest extends TestHelper with TransformableSelections {
  import global._

  val t123 = Literal(Constant(123))
  val tprint123 =
    "object O { /*(*/println(123)/*)*/ }".selection.selectedTopLevelTrees.head

  /**
   * Fluent assertions over the outcome of applying a transformation.
   *
   * This is a named class rather than an anonymous `new { ... }`. Members of an
   * anonymous class returned from a method are reachable only through a
   * structural type, which Scala calls via reflection (and which needs the
   * `reflectiveCalls` language feature).
   *
   * @param result value produced by applying the transformation; empty when
   *               the transformation did not apply
   */
  final class ReplacementAssertion(result: Option[Tree]) {

    /** Asserts that the transformation did not apply. */
    def toFail(): Unit =
      assertTrue(result.isEmpty)

    /** Asserts that the transformed tree prints the same as the tree parsed from `expectedSrc`. */
    def toBecome(expectedSrc: String): Unit = {
      val actualTree = transformedTree
      val (expected, actual) = global.ask { () =>
        (expectedSrc.root.toString(), actualTree.toString())
      }
      assertEquals(expected, actual)
    }

    /** Runs a custom `assertion` against the transformed tree. */
    def toBecomeTreeWith(assertion: Tree => Unit): Unit =
      assertion(transformedTree)

    /** The transformed tree; fails the test with a clear message instead of a bare `None.get`. */
    private def transformedTree: Tree =
      result.getOrElse(throw new AssertionError("expected the transformation to succeed, but it produced no result"))
  }

  /** Parses a source string and derives the selection marked in it with the text-selection markers. */
  implicit class StringToSel(src: String) {
    lazy val root = treeFrom(src)
    lazy val selection = {
      val textSelection = TextSelections.extractOne(src)
      FileSelection(root.pos.source.file, root, textSelection.from, textSelection.to)
    }

    /** Builds a transformation from this source's selection and applies it to `root`. */
    def assertReplacement(mkTrans: Selection => Transformation[Tree, Tree]): ReplacementAssertion = {
      val trans = mkTrans(selection)
      new ReplacementAssertion(trans(root))
    }
  }

  @Test
  def replaceSingleStatement() = global.ask { () => """
object O{
def f = /*(*/1/*)*/
}
""".assertReplacement(_.replaceBy(t123)).toBecome("""
object O{
def f = 123
}
""")
  }

  @Test
  def replaceSingleStatementInArgument() = global.ask { () => """
object O{
println(/*(*/1/*)*/)
}
""".assertReplacement(_.replaceBy(t123)).toBecome("""
object O{
println(123)
}
""")
  }

  @Test
  def replaceSequence() = global.ask { () => """
object O{
def f = {
/*(*/println(1)
println(2)/*)*/
println(3)
}
}
""".assertReplacement(_.replaceBy(t123)).toBecome("""
object O{
def f = {
123
println(3)
}
}
""")
  }

  @Test
  def replaceAllExpressionsInBlock() = global.ask { () => """
object O{
def f = {
/*(*/println(1)
println(2)
println(3)/*)*/
}
}
""".assertReplacement(_.replaceBy(tprint123)).toBecome("""
object O{
def f = println(123)
}
""")
  }

  @Test
  def replaceAllExpressionsInBlockPreservingHierarchy() = global.ask { () => """
object O{
def f = {
/*(*/println(1)
println(2)
println(3)/*)*/
}
}
""".assertReplacement(_.replaceBy(tprint123, preserveHierarchy = true)).toBecomeTreeWith { t =>
      val preservedBlock = t.find {
        // the new block must have an empty tree as its last expression
        case Block(_, EmptyTree) => true
        case _ => false
      }
      assertTrue(preservedBlock.isDefined)
    }
  }
}
122 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/util/FreshCompilerForeachTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring.tests.util
6 |
7 | import org.junit.After
8 | import scala.tools.refactoring.util.CompilerInstance
9 |
trait FreshCompilerForeachTest extends TestHelper {

  // Test runs have been unstable; as a mitigation, every test instance builds
  // its own fresh compiler instead of sharing one.
  override val global = (new CompilerInstance).compiler

  /** Asks this test's compiler to shut down after each test case. */
  @After
  def shutdownCompiler(): Unit = {
    // `askShutdown` is side-effecting, so it is called with explicit `()`.
    global.askShutdown()
  }
}
21 |
22 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/util/SourceHelpersTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.util
2 |
3 | import org.junit.Test
4 | import scala.tools.refactoring.util.SourceHelpers
5 | import org.junit.Assert._
6 | import scala.tools.refactoring.util.SourceWithSelection
7 |
class SourceHelpersTest {

  @Test
  def isRangeWithinWithTrivialArgs(): Unit = checkAll(
    ("", SourceWithSelection("x", 0, 1), false),
    ("x", SourceWithSelection("x", 0, 1), true),
    ("xx", SourceWithSelection("x", 0, 1), false))

  @Test
  def isRangeWithEmptySelections(): Unit = checkAll(
    ("", SourceWithSelection("", 0, 0), true),
    ("a", SourceWithSelection("a", 0, 0), true),
    ("b", SourceWithSelection("a", 0, 0), false))

  @Test
  def isRangeWithinWithSimpleExamples(): Unit = checkAll(
    ("abab", SourceWithSelection("0abab5", 4, 5), true),
    ("abab", SourceWithSelection("0abab5", 1, 5), true),
    ("abab", SourceWithSelection("0abab5", 1, 6), false),
    ("abab", SourceWithSelection("0abab5", 2, 4), true),
    ("ab", SourceWithSelection("012a", 3, 4), false),
    ("abba", SourceWithSelection("012abba", 3, 4), true))

  /** Runs `testIsRangeWithin` on each `(text, selection, expected)` case, in order. */
  private def checkAll(cases: (String, SourceWithSelection, Boolean)*): Unit =
    cases.foreach { case (text, selection, expected) =>
      testIsRangeWithin(text, selection, expected)
    }

  /** Asserts that `SourceHelpers.isRangeWithin(text, selection)` equals `expected`. */
  private def testIsRangeWithin(text: String, selection: SourceWithSelection, expected: Boolean): Unit =
    assertEquals(expected, SourceHelpers.isRangeWithin(text, selection))
}
38 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/util/TestRefactoring.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2005-2010 LAMP/EPFL
3 | */
4 |
5 | package scala.tools.refactoring
6 | package tests.util
7 |
8 | import common.Change
9 | import common.InteractiveScalaCompiler
10 | import scala.tools.refactoring.common.InteractiveScalaCompiler
11 |
trait TestRefactoring extends TestHelper {

  /** Thrown when the refactoring's preparation step reports an error. */
  class PreparationException(cause: String) extends Exception(cause)
  /** Thrown when the refactoring's perform step reports an error. */
  class RefactoringException(cause: String) extends Exception(cause)

  /**
   * Harness that prepares and performs a refactoring on the files of
   * `project` and returns the resulting changes.
   */
  abstract class TestRefactoringImpl(project: FileSet) {

    /**
     * `GlobalIndexes` implementation whose index covers the project's Scala
     * sources; the index is built inside `global.ask`.
     */
    trait TestProjectIndex extends GlobalIndexes {
      this: Refactoring =>

      val global = TestRefactoring.this.global

      // Java sources are added to the compiler, but unlike the Scala sources
      // their trees are not collected.
      lazy val trees = {
        project.javaSources.foreach { case (code, filename) => addToCompiler(filename, code) }
        project.sources.map { case(code, filename) => addToCompiler(filename, code) }.map(global.unitOfFile(_).body)
      }

      override val index = global.ask { () =>
        val cuIndexes = trees map (_.pos.source.file) map { file =>
          global.unitOfFile(file).body
        } map CompilationUnitIndex.apply
        GlobalIndex(cuIndexes)
      }
    }

    /** The refactoring under test. */
    val refactoring: MultiStageRefactoring with InteractiveScalaCompiler

    /** Runs `prepare` on the selection derived from `project`. */
    def preparationResult() = global.ask { () =>
      refactoring.prepare(selection(refactoring, project))
    }

    /** Runs `prepare` on the given selection. */
    def preparationResult(selection: => refactoring.Selection) = global.ask { () =>
      refactoring.prepare(selection)
    }

    /** Performs the refactoring on the selection derived from `project`, passing `null` as parameters. */
    def performRefactoring(): List[Change] = {
      performRefactoring(selection(refactoring, project), null.asInstanceOf[refactoring.RefactoringParameters])
    }

    /** Performs the refactoring on the selection derived from `project` with `parameters`. */
    def performRefactoring(parameters: refactoring.RefactoringParameters): List[Change] = {
      performRefactoring(selection(refactoring, project), parameters)
    }

    /**
     * Prepares and then performs the refactoring on `selection` with `parameters`.
     *
     * @throws PreparationException if preparation returns an error
     * @throws RefactoringException if performing returns an error
     */
    def performRefactoring(
        selection: => refactoring.Selection,
        parameters: refactoring.RefactoringParameters): List[Change] = global.ask { () =>
      // Evaluate the by-name argument exactly once. Re-evaluating it would rerun
      // the selection's construction, and prepare/perform must see the same value.
      val sel = selection
      preparationResult(sel) match {
        case Right(prepare) =>
          refactoring.perform(sel, prepare, parameters) match {
            case Right(modifications) => modifications
            case Left(error) => throw new RefactoringException(error.cause)
          }
        case Left(error) => throw new PreparationException(error.cause)
      }
    }
  }
}
69 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/util/TextSelectionsTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring.tests.util
2 |
3 | import org.junit.Test
4 | import org.junit.Assert.assertEquals
5 | import TextSelections.Range
6 |
7 |
class TextSelectionsTest {

  @Test(expected = classOf[IllegalArgumentException])
  def testWithEmptyString(): Unit = {
    TextSelections.extractOne("")
  }

  @Test(expected = classOf[IllegalArgumentException])
  def testWithMoreThanOneSelection(): Unit = {
    TextSelections.extractOne("/*(*/ /*)*/ /*<-*/")
  }

  @Test
  def testWithOneEmptySelection(): Unit = assertExtracts(
    "/*<-*/" -> Range(0, 0),
    "/*(*//*)*/" -> Range(5, 5))

  @Test
  def testWithOneNormalSelection(): Unit = assertExtracts(
    "/*(*/ /*)*/)" -> Range(5, 6))

  @Test
  def testWithOneSetCursorFromRightSelection(): Unit = assertExtracts(
    "012345/*<-cursor*/" -> Range(5, 6),
    "012345/*<-cursor-0*/" -> Range(5, 6),
    "012345/*<-cursor-1*/" -> Range(4, 5),
    "012345/*<-cursor-2*/" -> Range(3, 4))

  @Test(expected = classOf[IllegalArgumentException])
  def testWithCursorInvalidCursorFromRightSelection(): Unit = {
    TextSelections.extractOne("0123/*<-cursor-18*/")
  }

  @Test
  def testWithOneSetCursorFromLeftSelection(): Unit = assertExtracts(
    " /*cursor->*/45678" -> Range(14, 15),
    "/*0-cursor->*/45678" -> Range(14, 15),
    "/*1-cursor->*/45678" -> Range(15, 16),
    "/*2-cursor->*/45678" -> Range(16, 17))

  @Test(expected = classOf[IllegalArgumentException])
  def testWithCursorInvalidCursorFromLeftSelection(): Unit = {
    TextSelections.extractOne("/*cursor->*/")
  }

  /** Asserts, in order, that each source string yields its expected range. */
  private def assertExtracts(cases: (String, Range)*): Unit =
    for ((src, expected) <- cases)
      assertEquals(expected, TextSelections.extractOne(src))
}
56 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/util/UnionFindInitTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package tests.util
3 |
4 | import scala.util.Random
5 | import scala.tools.refactoring.util.UnionFind
6 | import org.junit.Test
7 | import org.junit.Assert._
8 |
class UnionFindInitTest {

  // We test this on 100 randomly colored Nodes. Labels must be distinct so
  // that different colors can be told apart in failure output.
  val colorString = Array("Red", "Blue", "Green", "Yellow", "Purple")

  class Node(val color: Int) {
    override def toString() = s"${hashCode()}( ${colorString(color)})"
  }

  // Colors are drawn from the label table's indices so the two cannot drift apart.
  val testNodes: List[Node] = List.fill(100)(new Node(Random.nextInt(colorString.length)))
  val uf = new UnionFind[this.Node]()

  /** After `find` has been called on every node, each node is its own representative. */
  @Test
  def firstInsertedNodesShouldBeTheirOwnParents(): Unit = {
    for (node <- testNodes) uf.find(node)
    // inserted Nodes are their Parents
    assertTrue(testNodes.forall(x => uf.find(x) == x))
  }
}
24 |
--------------------------------------------------------------------------------
/src/test/scala/scala/tools/refactoring/tests/util/UnionFindTest.scala:
--------------------------------------------------------------------------------
1 | package scala.tools.refactoring
2 | package tests.util
3 |
4 | import scala.util.Random
5 | import scala.tools.refactoring.util.UnionFind
6 | import org.junit.Test
7 | import org.junit.Assert._
8 | import org.junit.Before
9 |
/**
 * Property-style checks for [[UnionFind]]. Nodes are colored, every pair of same-colored nodes
 * is united before each test, and the resulting equivalence classes must coincide exactly with
 * the colors.
 */
class UnionFindTest {

  // Display name of each color. A node's color is an index into this array, so its length is
  // the number of colors. Names are distinct so failure messages are unambiguous.
  val colorString = Array("Red", "Blue", "Green", "Yellow", "Purple")

  private val colorCount = colorString.length

  /** A node compared by reference identity (no `equals` override); `color` indexes `colorString`. */
  class Node(val color: Int) {
    override def toString() = s"${hashCode()}(${colorString(color)})"
  }

  // Fixed seed: the coloring is arbitrary but identical on every run, so failures reproduce.
  private val random = new Random(42L)

  // We test this on 100 colored nodes. The first `colorCount` nodes take each color exactly
  // once, so every color is guaranteed to be present. Otherwise `colorRepresentant` would
  // fall back to a node outside `testNodes` and make `nodesForWhichFindIsIdentityAreReps`
  // fail spuriously. The remaining nodes are colored pseudo-randomly.
  val testNodes: List[Node] =
    List.tabulate(100)(i => new Node(if (i < colorCount) i else random.nextInt(colorCount)))

  val uf = new UnionFind[this.Node]()

  /**
   * Setup: unites every pair of same-colored nodes, starting from a `uf` that has not seen any
   * of them yet. The method name records that uniting unknown nodes must not throw.
   */
  @Before
  def unknownNodesShouldNotThrowWhenUnited(): Unit = {
    for {
      node1 <- testNodes
      node2 <- testNodes
      if node1.color == node2.color
    } uf.union(node1, node2)
  }

  @Test
  def atMostFiveRepsInUF(): Unit = {
    val repsLength = testNodes.map(node => uf.find(node)).distinct.length
    assertTrue(
      s"Expected at most $colorCount reps in the union-find, found $repsLength !",
      repsLength <= colorCount)
  }

  @Test
  def nodesInRelationIfAndOnlyIfWithSameColor(): Unit = {
    for (x <- testNodes; y <- testNodes) {
      val px = uf.find(x)
      val py = uf.find(y)
      // x and y share a representative exactly when they share a color.
      assertTrue(
        s"problem found with node $x (parent $px) and $y (parent $py)",
        (x.color == y.color) == (px == py))
    }
  }

  /**
   * Returns, for each color index, the representative of that color: `find` applied to the first
   * test node of that color. If no test node has the color, `find` is applied to a fresh node of
   * that color instead (presumably inserting it into `uf`). The fixture above guarantees every
   * color is present, so that fallback is not expected to run.
   */
  def colorRepresentant(): Array[Node] =
    Array.tabulate(colorCount) { c =>
      testNodes.find(_.color == c).map(node => uf.find(node)).getOrElse(uf.find(new Node(c)))
    }

  @Test
  def classRepresentantIsUnique(): Unit = {
    val reps = colorRepresentant()
    for (x <- testNodes) {
      val xColor = colorString(x.color)
      val xColorRep = reps(x.color)
      assertSame(s"problem found with $x yet the representant of $xColor is $xColorRep", xColorRep, uf.find(x))
    }
  }

  @Test
  def findIsIdempotent(): Unit = {
    for (x <- testNodes) {
      val p = uf.find(x)
      assertSame(s"find is not idempotent on $x (first result $p)", p, uf.find(p))
    }
  }

  @Test
  def nodesForWhichFindIsIdentityAreReps(): Unit = {
    val selfRepresented = testNodes.filter(n => uf.find(n) == n)
    val reps = colorRepresentant()
    for (rep <- reps)
      assertTrue(s"representant $rep is not its own parent", selfRepresented.contains(rep))
    for (node <- selfRepresented)
      assertTrue(s"$node is its own parent but is not a color representant", reps.contains(node))
  }

  @Test
  def equivalenceClassGivesAColor(): Unit = {
    // A node's equivalence class must contain exactly the test nodes of its color.
    def myClassIsExactlyMyColor(n: Node): Boolean = {
      val myClass = uf.equivalenceClass(n)
      testNodes.forall(x => myClass.contains(x) == (x.color == n.color))
    }
    // A bit overkill to do this on more than representants
    for (rep <- colorRepresentant())
      assertTrue(
        s"equivalence class of $rep does not match color ${colorString(rep.color)}",
        myClassIsExactlyMyColor(rep))
  }

}
87 |
--------------------------------------------------------------------------------